Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	release
Browse files- .gitattributes +6 -0
- README.md +32 -3
- app.py +232 -0
- generator.py +178 -0
- models.py +196 -0
- prompts/conversational_a.wav +3 -0
- prompts/conversational_b.wav +3 -0
- prompts/read_speech_a.wav +3 -0
- prompts/read_speech_b.wav +3 -0
- prompts/read_speech_c.wav +3 -0
- prompts/read_speech_d.wav +3 -0
- requirements.txt +11 -0
- watermarking.py +78 -0
    	
        .gitattributes
    CHANGED
    
    | @@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
| 36 | 
            +
            prompts/conversational_a.wav filter=lfs diff=lfs merge=lfs -text
         | 
| 37 | 
            +
            prompts/conversational_b.wav filter=lfs diff=lfs merge=lfs -text
         | 
| 38 | 
            +
            prompts/read_speech_a.wav filter=lfs diff=lfs merge=lfs -text
         | 
| 39 | 
            +
            prompts/read_speech_b.wav filter=lfs diff=lfs merge=lfs -text
         | 
| 40 | 
            +
            prompts/read_speech_c.wav filter=lfs diff=lfs merge=lfs -text
         | 
| 41 | 
            +
            prompts/read_speech_d.wav filter=lfs diff=lfs merge=lfs -text
         | 
    	
        README.md
    CHANGED
    
    | @@ -1,5 +1,5 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
            -
            title: CSM  | 
| 3 | 
             
            emoji: 🚀
         | 
| 4 | 
             
            colorFrom: blue
         | 
| 5 | 
             
            colorTo: blue
         | 
| @@ -8,7 +8,36 @@ sdk_version: 5.20.0 | |
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
             
            license: apache-2.0
         | 
| 11 | 
            -
            short_description:  | 
| 12 | 
             
            ---
         | 
| 13 |  | 
| 14 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
            +
            title: Sesame CSM Space
         | 
| 3 | 
             
            emoji: 🚀
         | 
| 4 | 
             
            colorFrom: blue
         | 
| 5 | 
             
            colorTo: blue
         | 
|  | |
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
             
            license: apache-2.0
         | 
| 11 | 
            +
            short_description: Generation using Sesame's Conversational Speech Model
         | 
| 12 | 
             
            ---
         | 
| 13 |  | 
| 14 | 
            +
            ## CSM 1B
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            **2025/03/13** - We are releasing the 1B CSM variant. Code is available on GitHub: [SesameAILabs/csm](https://github.com/SesameAILabs/csm). Checkpoint is [hosted on HuggingFace](https://huggingface.co/sesame/csm-1b).
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            Try out the interactive demo of our fine-tuned version [sesame.com/voicedemo](https://www.sesame.com/voicedemo).
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            Generate from the open-source base model [hosted on HuggingFace](https://huggingface.co/spaces/sesame/csm-1b).
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            ---
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            CSM (Conversational Speech Model) is a speech generation model from [Sesame](sesame.com) that generates RVQ audio codes from text and audio inputs. A fine-tuned version of this model powers the interactive demo in our [technical blog post](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice).
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            The model architecture employs a [Llama](https://www.llama.com/) backbone and a smaller audio decoder that produces [Mimi](https://huggingface.co/kyutai/mimi) audio codes.
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            ## Misuse and abuse ⚠️
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            This project provides a high-quality speech generation model for research and educational purposes. While we encourage responsible and ethical use, we **explicitly prohibit** the following:
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            - **Impersonation or Fraud**: Do not use this model to generate speech that mimics real individuals without their explicit consent.
         | 
| 33 | 
            +
            - **Misinformation or Deception**: Do not use this model to create deceptive or misleading content, such as fake news or fraudulent calls.
         | 
| 34 | 
            +
            - **Illegal or Harmful Activities**: Do not use this model for any illegal, harmful, or malicious purposes.
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            By using this model, you agree to comply with all applicable laws and ethical guidelines. We are **not responsible** for any misuse, and we strongly condemn unethical applications of this technology.
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            **Prompts**
         | 
| 39 | 
            +
            Conversational prompts are from the [EdAcc dataset](https://groups.inf.ed.ac.uk/edacc/)
         | 
| 40 | 
            +
            Read speech prompts are form the [LibriTTS-R dataset](https://google.github.io/df-conformer/librittsr/)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            **Authors**
         | 
| 43 | 
            +
            Johan Schalkwyk, Ankit Kumar, Dan Lyth, Sefik Emre Eskimez, Zack Hodari, Cinjon Resnick, Ramon Sanabria, Raven Jiang, and the Sesame team.
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,232 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import gradio as gr
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import spaces
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import torchaudio
         | 
| 8 | 
            +
            from generator import Segment, load_csm_1b
         | 
| 9 | 
            +
            from huggingface_hub import hf_hub_download, login
         | 
| 10 | 
            +
            from watermarking import watermark
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            api_key = os.getenv("HF_TOKEN")
         | 
| 13 | 
            +
            gpu_timeout = int(os.getenv("GPU_TIMEOUT", 60))
         | 
| 14 | 
            +
            CSM_1B_HF_WATERMARK = list(map(int, os.getenv("WATERMARK_KEY").split(" ")))
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            login(token=api_key)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            SPACE_INTRO_TEXT = """\
         | 
| 19 | 
            +
            # Sesame CSM 1B
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            Generate from CSM 1B (Conversational Speech Model). 
         | 
| 22 | 
            +
            Code is available on GitHub: [SesameAILabs/csm](https://github.com/SesameAILabs/csm). 
         | 
| 23 | 
            +
            Checkpoint is [hosted on HuggingFace](https://huggingface.co/sesame/csm-1b).
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            Try out the interactive demo of our fine-tuned model [sesame.com/voicedemo](https://www.sesame.com/voicedemo).
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            The model has some capacity for non-English languages due to data contamination in the training 
         | 
| 28 | 
            +
            data, but it is likely not to perform well.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            ---
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            """
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            CONVO_INTRO_TEXT = """\
         | 
| 35 | 
            +
            ## Conversation content
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            Each line is an utterance in the conversation to generate. Speakers alternate between A and B, starting with speaker A.
         | 
| 38 | 
            +
            """
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            DEFAULT_CONVERSATION = """\
         | 
| 41 | 
            +
            Hey how are you doing.
         | 
| 42 | 
            +
            Pretty good, pretty good.
         | 
| 43 | 
            +
            I'm great, so happy to be speaking to you.
         | 
| 44 | 
            +
            Me too, this is some cool stuff huh?
         | 
| 45 | 
            +
            Yeah, I've been reading more about speech generation, and it really seems like context is important.
         | 
| 46 | 
            +
            Definitely.
         | 
| 47 | 
            +
            """
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            SPEAKER_PROMPTS = {
         | 
| 50 | 
            +
                "conversational_a": {
         | 
| 51 | 
            +
                    "text": (
         | 
| 52 | 
            +
                        "like revising for an exam I'd have to try and like keep up the momentum because I'd "
         | 
| 53 | 
            +
                        "start really early I'd be like okay I'm gonna start revising now and then like "
         | 
| 54 | 
            +
                        "you're revising for ages and then I just like start losing steam I didn't do that "
         | 
| 55 | 
            +
                        "for the exam we had recently to be fair that was a more of a last minute scenario "
         | 
| 56 | 
            +
                        "but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I "
         | 
| 57 | 
            +
                        "sort of start the day with this not like a panic but like a"
         | 
| 58 | 
            +
                    ),
         | 
| 59 | 
            +
                    "audio": "prompts/conversational_a.wav",
         | 
| 60 | 
            +
                },
         | 
| 61 | 
            +
                "conversational_b": {
         | 
| 62 | 
            +
                    "text": (
         | 
| 63 | 
            +
                        "like a super Mario level. Like it's very like high detail. And like, once you get "
         | 
| 64 | 
            +
                        "into the park, it just like, everything looks like a computer game and they have all "
         | 
| 65 | 
            +
                        "these, like, you know, if, if there's like a, you know, like in a Mario game, they "
         | 
| 66 | 
            +
                        "will have like a question block. And if you like, you know, punch it, a coin will "
         | 
| 67 | 
            +
                        "come out. So like everyone, when they come into the park, they get like this little "
         | 
| 68 | 
            +
                        "bracelet and then you can go punching question blocks around."
         | 
| 69 | 
            +
                    ),
         | 
| 70 | 
            +
                    "audio": "prompts/conversational_b.wav",
         | 
| 71 | 
            +
                },
         | 
| 72 | 
            +
                "read_speech_a": {
         | 
| 73 | 
            +
                    "text": (
         | 
| 74 | 
            +
                        "And Lake turned round upon me, a little abruptly, his odd yellowish eyes, a little "
         | 
| 75 | 
            +
                        "like those of the sea eagle, and the ghost of his smile that flickered on his "
         | 
| 76 | 
            +
                        "singularly pale face, with a stern and insidious look, confronted me."
         | 
| 77 | 
            +
                    ),
         | 
| 78 | 
            +
                    "audio": "prompts/read_speech_a.wav",
         | 
| 79 | 
            +
                },
         | 
| 80 | 
            +
                "read_speech_b": {
         | 
| 81 | 
            +
                    "text": (
         | 
| 82 | 
            +
                        "He was such a big boy that he wore high boots and carried a jack knife. He gazed and "
         | 
| 83 | 
            +
                        "gazed at the cap, and could not keep from fingering the blue tassel."
         | 
| 84 | 
            +
                    ),
         | 
| 85 | 
            +
                    "audio": "prompts/read_speech_b.wav",
         | 
| 86 | 
            +
                },
         | 
| 87 | 
            +
                "read_speech_c": {
         | 
| 88 | 
            +
                    "text": (
         | 
| 89 | 
            +
                        "All passed so quickly, there was so much going on around him, the Tree quite forgot "
         | 
| 90 | 
            +
                        "to look to himself."
         | 
| 91 | 
            +
                    ),
         | 
| 92 | 
            +
                    "audio": "prompts/read_speech_c.wav",
         | 
| 93 | 
            +
                },
         | 
| 94 | 
            +
                "read_speech_d": {
         | 
| 95 | 
            +
                    "text": (
         | 
| 96 | 
            +
                        "Suddenly I was back in the old days Before you felt we ought to drift apart. It was "
         | 
| 97 | 
            +
                        "some trick-the way your eyebrows raise."
         | 
| 98 | 
            +
                    ),
         | 
| 99 | 
            +
                    "audio": "prompts/read_speech_d.wav",
         | 
| 100 | 
            +
                },
         | 
| 101 | 
            +
            }
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 104 | 
            +
            model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
         | 
| 105 | 
            +
            generator = load_csm_1b(model_path, device)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            @spaces.GPU(duration=gpu_timeout)
         | 
| 109 | 
            +
            def infer(
         | 
| 110 | 
            +
                text_prompt_speaker_a,
         | 
| 111 | 
            +
                text_prompt_speaker_b,
         | 
| 112 | 
            +
                audio_prompt_speaker_a,
         | 
| 113 | 
            +
                audio_prompt_speaker_b,
         | 
| 114 | 
            +
                gen_conversation_input,
         | 
| 115 | 
            +
            ) -> tuple[np.ndarray, int]:
         | 
| 116 | 
            +
                audio_prompt_a = prepare_prompt(text_prompt_speaker_a, 0, audio_prompt_speaker_a)
         | 
| 117 | 
            +
                audio_prompt_b = prepare_prompt(text_prompt_speaker_b, 1, audio_prompt_speaker_b)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                prompt_segments: list[Segment] = [audio_prompt_a, audio_prompt_b]
         | 
| 120 | 
            +
                generated_segments: list[Segment] = []
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                conversation_lines = [line.strip() for line in gen_conversation_input.strip().split("\n") if line.strip()]
         | 
| 123 | 
            +
                for i, line in enumerate(conversation_lines):
         | 
| 124 | 
            +
                    # Alternating speakers A and B, starting with A
         | 
| 125 | 
            +
                    speaker_id = i % 2
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    audio_tensor = generator.generate(
         | 
| 128 | 
            +
                        text=line,
         | 
| 129 | 
            +
                        speaker=speaker_id,
         | 
| 130 | 
            +
                        context=prompt_segments + generated_segments,
         | 
| 131 | 
            +
                    )
         | 
| 132 | 
            +
                    generated_segments.append(Segment(text=line, speaker=speaker_id, audio=audio_tensor))
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                # Concatenate all generations and convert to 16-bit int format
         | 
| 135 | 
            +
                audio_tensors = [segment.audio for segment in generated_segments]
         | 
| 136 | 
            +
                audio_tensor = torch.cat(audio_tensors, dim=0)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                # This applies an imperceptible watermark to identify audio as AI-generated.
         | 
| 139 | 
            +
                # Watermarking ensures transparency, dissuades misuse, and enables traceability.
         | 
| 140 | 
            +
                # Please be a responsible AI citizen and keep the watermarking in place.
         | 
| 141 | 
            +
                # If using CSM 1B in another application, use your own private key and keep it secret.
         | 
| 142 | 
            +
                audio_tensor, wm_sample_rate = watermark(
         | 
| 143 | 
            +
                    generator._watermarker, audio_tensor, generator.sample_rate, CSM_1B_HF_WATERMARK
         | 
| 144 | 
            +
                )
         | 
| 145 | 
            +
                audio_tensor = torchaudio.functional.resample(
         | 
| 146 | 
            +
                    audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
         | 
| 147 | 
            +
                )
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                audio_array = (audio_tensor * 32768).to(torch.int16).cpu().numpy()
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                return generator.sample_rate, audio_array
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            def prepare_prompt(text: str, speaker: int, audio_path: str) -> Segment:
         | 
| 155 | 
            +
                audio_tensor, _ = load_prompt_audio(audio_path)
         | 
| 156 | 
            +
                return Segment(text=text, speaker=speaker, audio=audio_tensor)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
             | 
| 159 | 
            +
            def load_prompt_audio(audio_path: str) -> torch.Tensor:
         | 
| 160 | 
            +
                audio_tensor, sample_rate = torchaudio.load(audio_path)
         | 
| 161 | 
            +
                audio_tensor = audio_tensor.squeeze(0)
         | 
| 162 | 
            +
                if sample_rate != generator.sample_rate:
         | 
| 163 | 
            +
                    audio_tensor = torchaudio.functional.resample(
         | 
| 164 | 
            +
                        audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate
         | 
| 165 | 
            +
                    )
         | 
| 166 | 
            +
                return audio_tensor, generator.sample_rate
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            def create_speaker_prompt_ui(speaker_name: str):
         | 
| 170 | 
            +
                speaker_dropdown = gr.Dropdown(
         | 
| 171 | 
            +
                    choices=list(SPEAKER_PROMPTS.keys()), label="Select a predefined speaker", value=speaker_name
         | 
| 172 | 
            +
                )
         | 
| 173 | 
            +
                with gr.Accordion("Or add your own voice prompt", open=False):
         | 
| 174 | 
            +
                    text_prompt_speaker = gr.Textbox(label="Speaker prompt", lines=4, value=SPEAKER_PROMPTS[speaker_name]["text"])
         | 
| 175 | 
            +
                    audio_prompt_speaker = gr.Audio(
         | 
| 176 | 
            +
                        label="Speaker prompt", type="filepath", value=SPEAKER_PROMPTS[speaker_name]["audio"]
         | 
| 177 | 
            +
                    )
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                return speaker_dropdown, text_prompt_speaker, audio_prompt_speaker
         | 
| 180 | 
            +
             | 
| 181 | 
            +
             | 
| 182 | 
            +
            with gr.Blocks() as app:
         | 
| 183 | 
            +
                gr.Markdown(SPACE_INTRO_TEXT)
         | 
| 184 | 
            +
                gr.Markdown("## Voices")
         | 
| 185 | 
            +
                with gr.Row():
         | 
| 186 | 
            +
                    with gr.Column():
         | 
| 187 | 
            +
                        gr.Markdown("### Speaker A")
         | 
| 188 | 
            +
                        speaker_a_dropdown, text_prompt_speaker_a, audio_prompt_speaker_a = create_speaker_prompt_ui(
         | 
| 189 | 
            +
                            "conversational_a"
         | 
| 190 | 
            +
                        )
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    with gr.Column():
         | 
| 193 | 
            +
                        gr.Markdown("### Speaker B")
         | 
| 194 | 
            +
                        speaker_b_dropdown, text_prompt_speaker_b, audio_prompt_speaker_b = create_speaker_prompt_ui(
         | 
| 195 | 
            +
                            "conversational_b"
         | 
| 196 | 
            +
                        )
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def update_audio(speaker):
         | 
| 199 | 
            +
                    if speaker in SPEAKER_PROMPTS:
         | 
| 200 | 
            +
                        return SPEAKER_PROMPTS[speaker]["audio"]
         | 
| 201 | 
            +
                    return None
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                def update_text(speaker):
         | 
| 204 | 
            +
                    if speaker in SPEAKER_PROMPTS:
         | 
| 205 | 
            +
                        return SPEAKER_PROMPTS[speaker]["text"]
         | 
| 206 | 
            +
                    return None
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                speaker_a_dropdown.change(fn=update_audio, inputs=[speaker_a_dropdown], outputs=[audio_prompt_speaker_a])
         | 
| 209 | 
            +
                speaker_b_dropdown.change(fn=update_audio, inputs=[speaker_b_dropdown], outputs=[audio_prompt_speaker_b])
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                speaker_a_dropdown.change(fn=update_text, inputs=[speaker_a_dropdown], outputs=[text_prompt_speaker_a])
         | 
| 212 | 
            +
                speaker_b_dropdown.change(fn=update_text, inputs=[speaker_b_dropdown], outputs=[text_prompt_speaker_b])
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                gr.Markdown(CONVO_INTRO_TEXT)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                gen_conversation_input = gr.TextArea(label="conversation", lines=20, value=DEFAULT_CONVERSATION)
         | 
| 217 | 
            +
                generate_btn = gr.Button("Generate conversation", variant="primary")
         | 
| 218 | 
            +
                audio_output = gr.Audio(label="Synthesized audio")
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                generate_btn.click(
         | 
| 221 | 
            +
                    infer,
         | 
| 222 | 
            +
                    inputs=[
         | 
| 223 | 
            +
                        text_prompt_speaker_a,
         | 
| 224 | 
            +
                        text_prompt_speaker_b,
         | 
| 225 | 
            +
                        audio_prompt_speaker_a,
         | 
| 226 | 
            +
                        audio_prompt_speaker_b,
         | 
| 227 | 
            +
                        gen_conversation_input,
         | 
| 228 | 
            +
                    ],
         | 
| 229 | 
            +
                    outputs=[audio_output],
         | 
| 230 | 
            +
                )
         | 
| 231 | 
            +
             | 
| 232 | 
            +
            app.launch(ssr_mode=True)
         | 
    	
        generator.py
    ADDED
    
    | @@ -0,0 +1,178 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from dataclasses import dataclass
         | 
| 3 | 
            +
            from typing import List, Tuple
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torchaudio
         | 
| 7 | 
            +
            from huggingface_hub import hf_hub_download
         | 
| 8 | 
            +
            from models import Model, ModelArgs
         | 
| 9 | 
            +
            from moshi.models import loaders
         | 
| 10 | 
            +
            from tokenizers.processors import TemplateProcessing
         | 
| 11 | 
            +
            from transformers import AutoTokenizer
         | 
| 12 | 
            +
            from watermarking import load_watermarker, watermark
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            CSM_1B_HF_WATERMARK = list(map(int, os.getenv("WATERMARK_KEY").split(" ")))
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            @dataclass
         | 
| 18 | 
            +
            class Segment:
         | 
| 19 | 
            +
                speaker: int
         | 
| 20 | 
            +
                text: str
         | 
| 21 | 
            +
                # (num_samples,), sample_rate = 24_000
         | 
| 22 | 
            +
                audio: torch.Tensor
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def load_llama3_tokenizer():
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
                tokenizer_name = "meta-llama/Llama-3.2-1B"
         | 
| 30 | 
            +
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
         | 
| 31 | 
            +
                bos = tokenizer.bos_token
         | 
| 32 | 
            +
                eos = tokenizer.eos_token
         | 
| 33 | 
            +
                tokenizer._tokenizer.post_processor = TemplateProcessing(
         | 
| 34 | 
            +
                    single=f"{bos}:0 $A:0 {eos}:0",
         | 
| 35 | 
            +
                    pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
         | 
| 36 | 
            +
                    special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
         | 
| 37 | 
            +
                )
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                return tokenizer
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            class Generator:
         | 
| 43 | 
            +
                def __init__(
         | 
| 44 | 
            +
                    self,
         | 
| 45 | 
            +
                    model: Model,
         | 
| 46 | 
            +
                ):
         | 
| 47 | 
            +
                    self._model = model
         | 
| 48 | 
            +
                    self._model.setup_caches(1)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    self._text_tokenizer = load_llama3_tokenizer()
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    device = next(model.parameters()).device
         | 
| 53 | 
            +
                    mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
         | 
| 54 | 
            +
                    mimi = loaders.get_mimi(mimi_weight, device=device)
         | 
| 55 | 
            +
                    mimi.set_num_codebooks(32)
         | 
| 56 | 
            +
                    self._audio_tokenizer = mimi
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    self._watermarker = load_watermarker(device=device)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    self.sample_rate = mimi.sample_rate
         | 
| 61 | 
            +
                    self.device = device
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 64 | 
            +
                    frame_tokens = []
         | 
| 65 | 
            +
                    frame_masks = []
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
         | 
| 68 | 
            +
                    text_frame = torch.zeros(len(text_tokens), 33).long()
         | 
| 69 | 
            +
                    text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
         | 
| 70 | 
            +
                    text_frame[:, -1] = torch.tensor(text_tokens)
         | 
| 71 | 
            +
                    text_frame_mask[:, -1] = True
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    frame_tokens.append(text_frame.to(self.device))
         | 
| 74 | 
            +
                    frame_masks.append(text_frame_mask.to(self.device))
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 79 | 
            +
                    frame_tokens = []
         | 
| 80 | 
            +
                    frame_masks = []
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    # (K, T)
         | 
| 83 | 
            +
                    audio = audio.to(self.device)
         | 
| 84 | 
            +
                    audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
         | 
| 85 | 
            +
                    # add EOS frame
         | 
| 86 | 
            +
                    eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
         | 
| 87 | 
            +
                    audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
         | 
| 90 | 
            +
                    audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
         | 
| 91 | 
            +
                    audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
         | 
| 92 | 
            +
                    audio_frame_mask[:, :-1] = True
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    frame_tokens.append(audio_frame)
         | 
| 95 | 
            +
                    frame_masks.append(audio_frame_mask)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
         | 
| 100 | 
            +
                    """
         | 
| 101 | 
            +
                    Returns:
         | 
| 102 | 
            +
                        (seq_len, 33), (seq_len, 33)
         | 
| 103 | 
            +
                    """
         | 
| 104 | 
            +
                    text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
         | 
| 105 | 
            +
                    audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                @torch.inference_mode()
         | 
| 110 | 
            +
                def generate(
         | 
| 111 | 
            +
                    self,
         | 
| 112 | 
            +
                    text: str,
         | 
| 113 | 
            +
                    speaker: int,
         | 
| 114 | 
            +
                    context: List[Segment],
         | 
| 115 | 
            +
                    max_audio_length_ms: float = 90_000,
         | 
| 116 | 
            +
                    temperature: float = 0.9,
         | 
| 117 | 
            +
                    topk: int = 50,
         | 
| 118 | 
            +
                ) -> torch.Tensor:
         | 
| 119 | 
            +
                    self._model.reset_caches()
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    max_audio_frames = int(max_audio_length_ms / 80)
         | 
| 122 | 
            +
                    tokens, tokens_mask = [], []
         | 
| 123 | 
            +
                    for segment in context:
         | 
| 124 | 
            +
                        segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
         | 
| 125 | 
            +
                        tokens.append(segment_tokens)
         | 
| 126 | 
            +
                        tokens_mask.append(segment_tokens_mask)
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
         | 
| 129 | 
            +
                    tokens.append(gen_segment_tokens)
         | 
| 130 | 
            +
                    tokens_mask.append(gen_segment_tokens_mask)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
         | 
| 133 | 
            +
                    prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    samples = []
         | 
| 136 | 
            +
                    curr_tokens = prompt_tokens.unsqueeze(0)
         | 
| 137 | 
            +
                    curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
         | 
| 138 | 
            +
                    curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    for _ in range(max_audio_frames):
         | 
| 141 | 
            +
                        sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
         | 
| 142 | 
            +
                        if torch.all(sample == 0):
         | 
| 143 | 
            +
                            break  # eos
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                        samples.append(sample)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                        curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
         | 
| 148 | 
            +
                        curr_tokens_mask = torch.cat(
         | 
| 149 | 
            +
                            [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
         | 
| 150 | 
            +
                        ).unsqueeze(1)
         | 
| 151 | 
            +
                        curr_pos = curr_pos[:, -1:] + 1
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    # This applies an imperceptible watermark to identify audio as AI-generated.
         | 
| 156 | 
            +
                    # Watermarking ensures transparency, dissuades misuse, and enables traceability.
         | 
| 157 | 
            +
                    # Please be a responsible AI citizen and keep the watermarking in place.
         | 
| 158 | 
            +
                    # If using CSM 1B in another application, use your own private key and keep it secret.
         | 
| 159 | 
            +
                    audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_HF_WATERMARK)
         | 
| 160 | 
            +
                    audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    return audio
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda") -> Generator:
         | 
| 166 | 
            +
                model_args = ModelArgs(
         | 
| 167 | 
            +
                    backbone_flavor="llama-1B",
         | 
| 168 | 
            +
                    decoder_flavor="llama-100M",
         | 
| 169 | 
            +
                    text_vocab_size=128256,
         | 
| 170 | 
            +
                    audio_vocab_size=2051,
         | 
| 171 | 
            +
                    audio_num_codebooks=32,
         | 
| 172 | 
            +
                )
         | 
| 173 | 
            +
                model = Model(model_args).to(device=device, dtype=torch.bfloat16)
         | 
| 174 | 
            +
                state_dict = torch.load(ckpt_path)
         | 
| 175 | 
            +
                model.load_state_dict(state_dict)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                generator = Generator(model)
         | 
| 178 | 
            +
                return generator
         | 
    	
        models.py
    ADDED
    
    | @@ -0,0 +1,196 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torchtune
         | 
| 6 | 
            +
            from torchtune.models import llama3_2
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
         | 
| 10 | 
            +
                return llama3_2.llama3_2(
         | 
| 11 | 
            +
                    vocab_size=128_256,
         | 
| 12 | 
            +
                    num_layers=16,
         | 
| 13 | 
            +
                    num_heads=32,
         | 
| 14 | 
            +
                    num_kv_heads=8,
         | 
| 15 | 
            +
                    embed_dim=2048,
         | 
| 16 | 
            +
                    max_seq_len=2048,
         | 
| 17 | 
            +
                    intermediate_dim=8192,
         | 
| 18 | 
            +
                    attn_dropout=0.0,
         | 
| 19 | 
            +
                    norm_eps=1e-5,
         | 
| 20 | 
            +
                    rope_base=500_000,
         | 
| 21 | 
            +
                    scale_factor=32,
         | 
| 22 | 
            +
                )
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
         | 
| 26 | 
            +
                return llama3_2.llama3_2(
         | 
| 27 | 
            +
                    vocab_size=128_256,
         | 
| 28 | 
            +
                    num_layers=4,
         | 
| 29 | 
            +
                    num_heads=8,
         | 
| 30 | 
            +
                    num_kv_heads=2,
         | 
| 31 | 
            +
                    embed_dim=1024,
         | 
| 32 | 
            +
                    max_seq_len=2048,
         | 
| 33 | 
            +
                    intermediate_dim=8192,
         | 
| 34 | 
            +
                    attn_dropout=0.0,
         | 
| 35 | 
            +
                    norm_eps=1e-5,
         | 
| 36 | 
            +
                    rope_base=500_000,
         | 
| 37 | 
            +
                    scale_factor=32,
         | 
| 38 | 
            +
                )
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            FLAVORS = {
         | 
| 42 | 
            +
                "llama-1B": llama3_2_1B,
         | 
| 43 | 
            +
                "llama-100M": llama3_2_100M,
         | 
| 44 | 
            +
            }
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def _prepare_transformer(model):
         | 
| 48 | 
            +
                embed_dim = model.tok_embeddings.embedding_dim
         | 
| 49 | 
            +
                model.tok_embeddings = nn.Identity()
         | 
| 50 | 
            +
                model.output = nn.Identity()
         | 
| 51 | 
            +
                return model, embed_dim
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def _create_causal_mask(seq_len: int, device: torch.device):
         | 
| 55 | 
            +
                return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
         | 
| 59 | 
            +
                """
         | 
| 60 | 
            +
                Args:
         | 
| 61 | 
            +
                    mask: (max_seq_len, max_seq_len)
         | 
| 62 | 
            +
                    input_pos: (batch_size, seq_len)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                Returns:
         | 
| 65 | 
            +
                    (batch_size, seq_len, max_seq_len)
         | 
| 66 | 
            +
                """
         | 
| 67 | 
            +
                r = mask[input_pos, :]
         | 
| 68 | 
            +
                return r
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            def _multinomial_sample_one_no_sync(probs):  # Does multinomial sampling without a cuda synchronization
         | 
| 72 | 
            +
                q = torch.empty_like(probs).exponential_(1)
         | 
| 73 | 
            +
                return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
         | 
| 77 | 
            +
                logits = logits / temperature
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                filter_value: float = -float("Inf")
         | 
| 80 | 
            +
                indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
         | 
| 81 | 
            +
                scores_processed = logits.masked_fill(indices_to_remove, filter_value)
         | 
| 82 | 
            +
                scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
         | 
| 83 | 
            +
                probs = torch.nn.functional.softmax(scores_processed, dim=-1)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                sample_token = _multinomial_sample_one_no_sync(probs)
         | 
| 86 | 
            +
                return sample_token
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            @dataclass
         | 
| 90 | 
            +
            class ModelArgs:
         | 
| 91 | 
            +
                backbone_flavor: str
         | 
| 92 | 
            +
                decoder_flavor: str
         | 
| 93 | 
            +
                text_vocab_size: int
         | 
| 94 | 
            +
                audio_vocab_size: int
         | 
| 95 | 
            +
                audio_num_codebooks: int
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            class Model(nn.Module):
         | 
| 99 | 
            +
                def __init__(self, args: ModelArgs):
         | 
| 100 | 
            +
                    super().__init__()
         | 
| 101 | 
            +
                    self.args = args
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    self.backbone, backbone_dim = _prepare_transformer(FLAVORS[args.backbone_flavor]())
         | 
| 104 | 
            +
                    self.decoder, decoder_dim = _prepare_transformer(FLAVORS[args.decoder_flavor]())
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    self.text_embeddings = nn.Embedding(args.text_vocab_size, backbone_dim)
         | 
| 107 | 
            +
                    self.audio_embeddings = nn.Embedding(args.audio_vocab_size * args.audio_num_codebooks, backbone_dim)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
         | 
| 110 | 
            +
                    self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
         | 
| 111 | 
            +
                    self.audio_head = nn.Parameter(torch.empty(args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size))
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                def setup_caches(self, max_batch_size: int) -> torch.Tensor:
         | 
| 114 | 
            +
                    """Setup KV caches and return a causal mask."""
         | 
| 115 | 
            +
                    dtype = next(self.parameters()).dtype
         | 
| 116 | 
            +
                    device = next(self.parameters()).device
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    with device:
         | 
| 119 | 
            +
                        self.backbone.setup_caches(max_batch_size, dtype)
         | 
| 120 | 
            +
                        self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.args.audio_num_codebooks)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
         | 
| 123 | 
            +
                    self.register_buffer("decoder_causal_mask", _create_causal_mask(self.args.audio_num_codebooks, device))
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def generate_frame(
         | 
| 126 | 
            +
                    self,
         | 
| 127 | 
            +
                    tokens: torch.Tensor,
         | 
| 128 | 
            +
                    tokens_mask: torch.Tensor,
         | 
| 129 | 
            +
                    input_pos: torch.Tensor,
         | 
| 130 | 
            +
                    temperature: float,
         | 
| 131 | 
            +
                    topk: int,
         | 
| 132 | 
            +
                ) -> torch.Tensor:
         | 
| 133 | 
            +
                    """
         | 
| 134 | 
            +
                    Args:
         | 
| 135 | 
            +
                        tokens: (batch_size, seq_len, audio_num_codebooks+1)
         | 
| 136 | 
            +
                        tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
         | 
| 137 | 
            +
                        input_pos: (batch_size, seq_len) positions for each token
         | 
| 138 | 
            +
                        mask: (batch_size, seq_len, max_seq_len
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    Returns:
         | 
| 141 | 
            +
                        (batch_size, audio_num_codebooks) sampled tokens
         | 
| 142 | 
            +
                    """
         | 
| 143 | 
            +
                    dtype = next(self.parameters()).dtype
         | 
| 144 | 
            +
                    b, s, _ = tokens.size()
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
         | 
| 147 | 
            +
                    curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
         | 
| 148 | 
            +
                    embeds = self._embed_tokens(tokens)
         | 
| 149 | 
            +
                    masked_embeds = embeds * tokens_mask.unsqueeze(-1)
         | 
| 150 | 
            +
                    h = masked_embeds.sum(dim=2)
         | 
| 151 | 
            +
                    h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    last_h = h[:, -1, :]
         | 
| 154 | 
            +
                    c0_logits = self.codebook0_head(last_h)
         | 
| 155 | 
            +
                    c0_sample = sample_topk(c0_logits, topk, temperature)
         | 
| 156 | 
            +
                    c0_embed = self._embed_audio(0, c0_sample)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
         | 
| 159 | 
            +
                    curr_sample = c0_sample.clone()
         | 
| 160 | 
            +
                    curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    # Decoder caches must be reset every frame.
         | 
| 163 | 
            +
                    self.decoder.reset_caches()
         | 
| 164 | 
            +
                    for i in range(1, self.args.audio_num_codebooks):
         | 
| 165 | 
            +
                        curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
         | 
| 166 | 
            +
                        decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
         | 
| 167 | 
            +
                            dtype=dtype
         | 
| 168 | 
            +
                        )
         | 
| 169 | 
            +
                        ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
         | 
| 170 | 
            +
                        ci_sample = sample_topk(ci_logits, topk, temperature)
         | 
| 171 | 
            +
                        ci_embed = self._embed_audio(i, ci_sample)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                        curr_h = ci_embed
         | 
| 174 | 
            +
                        curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
         | 
| 175 | 
            +
                        curr_pos = curr_pos[:, -1:] + 1
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    return curr_sample
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                def reset_caches(self):
         | 
| 180 | 
            +
                    self.backbone.reset_caches()
         | 
| 181 | 
            +
                    self.decoder.reset_caches()
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
         | 
| 184 | 
            +
                    return self.audio_embeddings(tokens + codebook * self.args.audio_vocab_size)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
         | 
| 187 | 
            +
                    text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    audio_tokens = tokens[:, :, :-1] + (
         | 
| 190 | 
            +
                        self.args.audio_vocab_size * torch.arange(self.args.audio_num_codebooks, device=tokens.device)
         | 
| 191 | 
            +
                    )
         | 
| 192 | 
            +
                    audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
         | 
| 193 | 
            +
                        tokens.size(0), tokens.size(1), self.args.audio_num_codebooks, -1
         | 
| 194 | 
            +
                    )
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    return torch.cat([audio_embeds, text_embeds], dim=-2)
         | 
    	
        prompts/conversational_a.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:356648c1bc6c1da7883004557e9b21a2ef7d01682d8b9d02d6dcb950b348b04f
         | 
| 3 | 
            +
            size 2646044
         | 
    	
        prompts/conversational_b.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:c247153011385d33aaeed193adfec562c32182e2facd30cc8cd0b3e820e94afb
         | 
| 3 | 
            +
            size 2646044
         | 
    	
        prompts/read_speech_a.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:59480708f84c77ab2967d14d821c2ccade9d7761685d060575121f49a149005b
         | 
| 3 | 
            +
            size 831412
         | 
    	
        prompts/read_speech_b.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:f582640265864499cbe6a8c687ea0f9e08e7fa41eeb2caa923d0a3bada55fcef
         | 
| 3 | 
            +
            size 576052
         | 
    	
        prompts/read_speech_c.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:7da15ab3ee7f8bbc8abfce73ce65936a80a535ae4a86db2d9c4756caba69e9c3
         | 
| 3 | 
            +
            size 385964
         | 
    	
        prompts/read_speech_d.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:09cad0494f9d0038b0f0eb039f47d752c45e56d92679f96587e20f67b2c1b7d8
         | 
| 3 | 
            +
            size 435884
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,11 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            torch==2.4.0
         | 
| 2 | 
            +
            torchaudio==2.4.0
         | 
| 3 | 
            +
            tokenizers==0.21.0
         | 
| 4 | 
            +
            transformers==4.49.0
         | 
| 5 | 
            +
            huggingface_hub==0.28.1
         | 
| 6 | 
            +
            spaces==0.32.0
         | 
| 7 | 
            +
            gradio==5.20.1
         | 
| 8 | 
            +
            moshi==0.2.2
         | 
| 9 | 
            +
            torchtune==0.4.0
         | 
| 10 | 
            +
            torchao==0.9.0
         | 
| 11 | 
            +
            silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
         | 
    	
        watermarking.py
    ADDED
    
    | @@ -0,0 +1,78 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import silentcipher
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torchaudio
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            CSM_1B_HF_WATERMARK = list(map(int, os.getenv("WATERMARK_KEY").split(" ")))
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def cli_check_audio() -> None:
         | 
| 12 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 13 | 
            +
                parser.add_argument("--audio_path", type=str, required=True)
         | 
| 14 | 
            +
                args = parser.parse_args()
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                check_audio_from_file(args.audio_path)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
         | 
| 20 | 
            +
                model = silentcipher.get_model(
         | 
| 21 | 
            +
                    model_type="44.1k",
         | 
| 22 | 
            +
                    device=device,
         | 
| 23 | 
            +
                )
         | 
| 24 | 
            +
                return model
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            @torch.inference_mode()
         | 
| 28 | 
            +
            def watermark(
         | 
| 29 | 
            +
                watermarker: silentcipher.server.Model,
         | 
| 30 | 
            +
                audio_array: torch.Tensor,
         | 
| 31 | 
            +
                sample_rate: int,
         | 
| 32 | 
            +
                watermark_key: list[int],
         | 
| 33 | 
            +
            ) -> tuple[torch.Tensor, int]:
         | 
| 34 | 
            +
                audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
         | 
| 35 | 
            +
                encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                output_sample_rate = min(44100, sample_rate)
         | 
| 38 | 
            +
                encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
         | 
| 39 | 
            +
                return encoded, output_sample_rate
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            @torch.inference_mode()
         | 
| 43 | 
            +
            def verify(
         | 
| 44 | 
            +
                watermarker: silentcipher.server.Model,
         | 
| 45 | 
            +
                watermarked_audio: torch.Tensor,
         | 
| 46 | 
            +
                sample_rate: int,
         | 
| 47 | 
            +
                watermark_key: list[int],
         | 
| 48 | 
            +
            ) -> bool:
         | 
| 49 | 
            +
                watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
         | 
| 50 | 
            +
                result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                is_watermarked = result["status"]
         | 
| 53 | 
            +
                if is_watermarked:
         | 
| 54 | 
            +
                    is_csm_watermarked = result["messages"][0] == watermark_key
         | 
| 55 | 
            +
                else:
         | 
| 56 | 
            +
                    is_csm_watermarked = False
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                return is_watermarked and is_csm_watermarked
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def check_audio_from_file(audio_path: str) -> None:
         | 
| 62 | 
            +
                watermarker = load_watermarker(device="cuda")
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                audio_array, sample_rate = load_audio(audio_path)
         | 
| 65 | 
            +
                is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_HF_WATERMARK)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                outcome = "Watermarked" if is_watermarked else "Not watermarked"
         | 
| 68 | 
            +
                print(f"{outcome}: {audio_path}")
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
         | 
| 72 | 
            +
                audio_array, sample_rate = torchaudio.load(audio_path)
         | 
| 73 | 
            +
                audio_array = audio_array.mean(dim=0)
         | 
| 74 | 
            +
                return audio_array, int(sample_rate)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            if __name__ == "__main__":
         | 
| 78 | 
            +
                cli_check_audio()
         | 
