	Upload 409 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
 - .gitattributes +3 -0
 - README.md +169 -14
 - models/__init__.py +0 -0
 - models/base/__init__.py +7 -0
 - models/base/base_dataset.py +464 -0
 - models/base/base_inference.py +220 -0
 - models/base/base_sampler.py +157 -0
 - models/base/base_trainer.py +348 -0
 - models/base/new_dataset.py +50 -0
 - models/base/new_inference.py +253 -0
 - models/base/new_trainer.py +727 -0
 - models/codec/__init__.py +0 -0
 - models/codec/amphion_codec/codec.py +427 -0
 - models/codec/amphion_codec/quantize/__init__.py +11 -0
 - models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
 - models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
 - models/codec/amphion_codec/quantize/residual_vq.py +177 -0
 - models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
 - models/codec/amphion_codec/vocos.py +881 -0
 - models/codec/codec_dataset.py +264 -0
 - models/codec/codec_inference.py +515 -0
 - models/codec/codec_sampler.py +126 -0
 - models/codec/codec_trainer.py +166 -0
 - models/codec/facodec/__init__.py +0 -0
 - models/codec/facodec/alias_free_torch/__init__.py +5 -0
 - models/codec/facodec/alias_free_torch/act.py +29 -0
 - models/codec/facodec/alias_free_torch/filter.py +96 -0
 - models/codec/facodec/alias_free_torch/resample.py +57 -0
 - models/codec/facodec/facodec_dataset.py +98 -0
 - models/codec/facodec/facodec_inference.py +137 -0
 - models/codec/facodec/facodec_trainer.py +776 -0
 - models/codec/facodec/modules/JDC/__init__.py +1 -0
 - models/codec/facodec/modules/JDC/bst.t7 +3 -0
 - models/codec/facodec/modules/JDC/model.py +219 -0
 - models/codec/facodec/modules/attentions.py +437 -0
 - models/codec/facodec/modules/commons.py +331 -0
 - models/codec/facodec/modules/gradient_reversal.py +35 -0
 - models/codec/facodec/modules/layers.py +460 -0
 - models/codec/facodec/modules/quantize.py +741 -0
 - models/codec/facodec/modules/style_encoder.py +110 -0
 - models/codec/facodec/modules/wavenet.py +224 -0
 - models/codec/facodec/optimizer.py +104 -0
 - models/codec/kmeans/repcodec_model.py +210 -0
 - models/codec/kmeans/vocos.py +850 -0
 - models/codec/ns3_codec/README.md +216 -0
 - models/codec/ns3_codec/__init__.py +6 -0
 - models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
 - models/codec/ns3_codec/alias_free_torch/act.py +29 -0
 - models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
 - models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
 
    	
.gitattributes CHANGED

@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 imgs/vocoder/gan/MSSBCQTD.png filter=lfs diff=lfs merge=lfs -text
+models/codec/facodec/modules/JDC/bst.t7 filter=lfs diff=lfs merge=lfs -text
+models/tts/maskgct/g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
+models/tts/maskgct/wav/prompt.wav filter=lfs diff=lfs merge=lfs -text
    	
README.md CHANGED

@@ -1,14 +1,169 @@
# Amphion: An Open-Source Audio, Music, and Speech Generation Toolkit

<div>
    <a href="https://arxiv.org/abs/2312.09911"><img src="https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg"></a>
    <a href="https://huggingface.co/amphion"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Amphion-pink"></a>
    <a href="https://openxlab.org.cn/usercenter/Amphion"><img src="https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg"></a>
    <a href="https://discord.com/invite/ZxxREr3Y"><img src="https://img.shields.io/badge/Discord-Join%20chat-blue.svg"></a>
    <a href="egs/tts/README.md"><img src="https://img.shields.io/badge/README-TTS-blue"></a>
    <a href="egs/svc/README.md"><img src="https://img.shields.io/badge/README-SVC-blue"></a>
    <a href="egs/tta/README.md"><img src="https://img.shields.io/badge/README-TTA-blue"></a>
    <a href="egs/vocoder/README.md"><img src="https://img.shields.io/badge/README-Vocoder-purple"></a>
    <a href="egs/metrics/README.md"><img src="https://img.shields.io/badge/README-Evaluation-yellow"></a>
    <a href="LICENSE"><img src="https://img.shields.io/badge/LICENSE-MIT-red"></a>
</div>
<br>

**Amphion (/æmˈfaɪən/) is a toolkit for Audio, Music, and Speech Generation.** Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development. Amphion offers a unique feature: **visualizations** of classic models or architectures. We believe that these visualizations are beneficial for junior researchers and engineers who wish to gain a better understanding of the models.

**The North-Star objective of Amphion is to offer a platform for studying the conversion of any inputs into audio.** Amphion is designed to support individual generation tasks, including but not limited to:

- **TTS**: Text to Speech (⛳ supported)
- **SVS**: Singing Voice Synthesis (👨💻 developing)
- **VC**: Voice Conversion (👨💻 developing)
- **SVC**: Singing Voice Conversion (⛳ supported)
- **TTA**: Text to Audio (⛳ supported)
- **TTM**: Text to Music (👨💻 developing)
- more…

In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent measurement in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis.

## 🚀 News

- **2024/10/19**: We release **MaskGCT**, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision. MaskGCT is trained on the Emilia dataset and achieves SOTA zero-shot TTS performance. [paper](https://arxiv.org/abs/2409.00750) [model](https://huggingface.co/amphion/maskgct) [demo](https://huggingface.co/spaces/amphion/maskgct) [README](models/tts/maskgct/README.md)
- **2024/09/01**: [Amphion](https://arxiv.org/abs/2312.09911), [Emilia](https://arxiv.org/abs/2407.05361) and [DSFF-SVC](https://arxiv.org/abs/2310.11160) got accepted by IEEE SLT 2024! 🤗
- **2024/08/28**: Welcome to join Amphion's [Discord channel](https://discord.gg/drhW7ajqAG) to stay connected and engage with our community!
- **2024/08/27**: *The Emilia dataset is now publicly available!* Discover the most extensive and diverse speech generation dataset with 101k hours of in-the-wild speech data now at [HuggingFace](https://huggingface.co/datasets/amphion/Emilia-Dataset) or [OpenDataLab](https://opendatalab.com/Amphion/Emilia)! 👑👑👑
- **2024/08/20**: [SingVisio](https://arxiv.org/abs/2402.12660) got accepted by Computers & Graphics, [available here](https://www.sciencedirect.com/science/article/pii/S0097849324001936)! 🎉
- **2024/07/01**: Amphion now releases **Emilia**, the first open-source multilingual in-the-wild dataset for speech generation with over 101k hours of speech data, and the **Emilia-Pipe**, the first open-source preprocessing pipeline designed to transform in-the-wild speech data into high-quality training data with annotations for speech generation! [paper](https://arxiv.org/abs/2407.05361) [dataset](https://huggingface.co/datasets/amphion/Emilia) [demo](https://emilia-dataset.github.io/Emilia-Demo-Page/) [README](preprocessors/Emilia/README.md)
- **2024/06/17**: Amphion has a new release for its **VALL-E** model! It uses Llama as its underlying architecture and has better model performance, faster training speed, and more readable code compared to our first version. [README](egs/tts/VALLE_V2/README.md)
- **2024/03/12**: Amphion now supports **NaturalSpeech3 FACodec** and releases pretrained checkpoints. [paper](https://arxiv.org/abs/2403.03100) [model](https://huggingface.co/amphion/naturalspeech3_facodec) [demo](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) [README](models/codec/ns3_codec/README.md)
- **2024/02/22**: The first Amphion visualization tool, **SingVisio**, is released. [paper](https://arxiv.org/abs/2402.12660) [demo](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [video](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view) [README](egs/visualization/SingVisio/README.md)
- **2023/12/18**: Amphion v0.1 released. [paper](https://arxiv.org/abs/2312.09911) [model](https://huggingface.co/amphion) [video](https://www.youtube.com/watch?v=1aw0HhcggvQ) [PR](https://github.com/open-mmlab/Amphion/pull/39)
- **2023/11/28**: Amphion alpha released. [PR](https://github.com/open-mmlab/Amphion/pull/2)

## ⭐ Key Features

### TTS: Text to Speech

- Amphion achieves state-of-the-art performance compared to existing open-source repositories on text-to-speech (TTS) systems. It supports the following models or architectures (a toy sketch of the duration-based length regulation used by FastSpeech2-style models follows this list):
    - [FastSpeech2](https://arxiv.org/abs/2006.04558): A non-autoregressive TTS architecture that utilizes feed-forward Transformer blocks.
    - [VITS](https://arxiv.org/abs/2106.06103): An end-to-end TTS architecture that utilizes a conditional variational autoencoder with adversarial learning.
    - [VALL-E](https://arxiv.org/abs/2301.02111): A zero-shot TTS architecture that uses a neural codec language model with discrete codes.
    - [NaturalSpeech2](https://arxiv.org/abs/2304.09116): An architecture for TTS that utilizes a latent diffusion model to generate natural-sounding voices.
    - [Jets](Jets): An end-to-end TTS model that jointly trains FastSpeech2 and HiFi-GAN with an alignment module.
    - [MaskGCT](https://arxiv.org/abs/2409.00750): A fully non-autoregressive TTS architecture that eliminates the need for explicit alignment information between text and speech supervision.

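As an illustration of the non-autoregressive idea behind FastSpeech2-style models, here is a minimal, hypothetical length-regulator sketch (the function and variable names are our own for illustration; this is not Amphion's implementation). It expands each phoneme's hidden state by its predicted duration so that all output frames can be generated in parallel:

```python
# Minimal length-regulator sketch (illustrative only).
# FastSpeech2-style models expand each phoneme's hidden vector by its predicted
# duration (in frames), turning a phoneme-level sequence into a frame-level one.
import torch

def length_regulate(hidden: torch.Tensor, durations: torch.Tensor) -> torch.Tensor:
    """hidden: (num_phones, dim); durations: (num_phones,) integer frame counts."""
    # repeat_interleave copies each phoneme vector durations[i] times along time
    return torch.repeat_interleave(hidden, durations, dim=0)

phone_hidden = torch.randn(4, 8)             # 4 phonemes, 8-dim hidden states
pred_durations = torch.tensor([3, 1, 5, 2])  # predicted frames per phoneme
frames = length_regulate(phone_hidden, pred_durations)
print(frames.shape)                          # torch.Size([11, 8]) == sum of durations
```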
### SVC: Singing Voice Conversion

- Amphion supports multiple content-based features from various pretrained models, including [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec). Their specific roles in SVC have been investigated in our SLT 2024 paper. [paper](https://arxiv.org/abs/2310.11160) [README](egs/svc/MultipleContentsSVC)
- Amphion implements several state-of-the-art model architectures, including diffusion-, transformer-, VAE-, and flow-based models. The diffusion-based architecture uses a [bidirectional dilated CNN](https://openreview.net/pdf?id=a-xFK8Ymz5J) as its backbone and supports several sampling algorithms, such as [DDPM](https://arxiv.org/pdf/2006.11239.pdf), [DDIM](https://arxiv.org/pdf/2010.02502.pdf), and [PNDM](https://arxiv.org/pdf/2202.09778.pdf). Additionally, it supports single-step inference based on the [Consistency Model](https://openreview.net/pdf?id=FmqFfMTNnv). A minimal sketch of a DDIM-style update step follows below.

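To make the sampler abstraction concrete, here is a minimal, self-contained sketch of one deterministic DDIM update step (eta = 0). The names and shapes are illustrative assumptions, not Amphion's sampler code:

```python
# One deterministic DDIM update step (eta = 0), illustrative only.
# x_t: current noisy sample; eps: the denoiser's noise prediction at step t;
# abar_t / abar_prev: cumulative alpha-bar products at steps t and t-1.
import torch

def ddim_step(x_t: torch.Tensor, eps: torch.Tensor,
              abar_t: float, abar_prev: float) -> torch.Tensor:
    # predict the clean sample x0 from x_t and the noise estimate
    x0_pred = (x_t - (1.0 - abar_t) ** 0.5 * eps) / abar_t ** 0.5
    # move to the previous timestep along the deterministic DDIM trajectory
    return abar_prev ** 0.5 * x0_pred + (1.0 - abar_prev) ** 0.5 * eps

x_t = torch.randn(1, 100, 80)   # e.g. a noisy mel-spectrogram (batch, frames, bins)
eps = torch.randn_like(x_t)     # stand-in for the denoiser's prediction
x_prev = ddim_step(x_t, eps, abar_t=0.5, abar_prev=0.7)
```

Setting eta > 0 adds a noise term back into the update and interpolates toward stochastic DDPM-style sampling.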
### TTA: Text to Audio

- Amphion supports TTA with a latent diffusion model. It is designed like [AudioLDM](https://arxiv.org/abs/2301.12503), [Make-an-Audio](https://arxiv.org/abs/2301.12661), and [AUDIT](https://arxiv.org/abs/2304.00830). It is also the official implementation of the text-to-audio generation part of our NeurIPS 2023 paper. A sketch of the diffusion forward (noising) process these models are trained to invert follows below. [paper](https://arxiv.org/abs/2304.00830) [README](egs/tta/RECIPE.md)

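For intuition, a latent diffusion model is trained to undo the closed-form forward noising process below. This is a generic sketch (names and the latent shape are assumptions, not Amphion's code), applied to a VAE latent `z0`:

```python
# Forward (noising) process that a latent diffusion TTA model learns to invert:
# z_t = sqrt(abar_t) * z0 + sqrt(1 - abar_t) * eps.
# The denoiser is trained to recover eps from (z_t, t) and the text condition.
import torch

def q_sample(z0, abar_t):
    eps = torch.randn_like(z0)  # Gaussian noise
    z_t = abar_t ** 0.5 * z0 + (1.0 - abar_t) ** 0.5 * eps
    return z_t, eps

z0 = torch.randn(1, 8, 16, 16)  # assumed shape of a VAE latent of a mel-spectrogram
z_t, eps = q_sample(z0, abar_t=0.3)
print(z_t.shape)
```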
### Vocoder

- Amphion supports various widely-used neural vocoders, including:
    - GAN-based vocoders: [MelGAN](https://arxiv.org/abs/1910.06711), [HiFi-GAN](https://arxiv.org/abs/2010.05646), [NSF-HiFiGAN](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts), [BigVGAN](https://arxiv.org/abs/2206.04658), [APNet](https://arxiv.org/abs/2305.07952).
    - Flow-based vocoders: [WaveGlow](https://arxiv.org/abs/1811.00002).
    - Diffusion-based vocoders: [Diffwave](https://arxiv.org/abs/2009.09761).
    - Auto-regressive vocoders: [WaveNet](https://arxiv.org/abs/1609.03499), [WaveRNN](https://arxiv.org/abs/1802.08435v1).
- Amphion provides the official implementation of the [Multi-Scale Constant-Q Transform Discriminator](https://arxiv.org/abs/2311.14957) (our ICASSP 2024 paper). It can be used during training to enhance any GAN-based vocoder architecture while leaving the inference stage (e.g., memory footprint and speed) unchanged; a sketch of the multi-resolution CQT inputs such a discriminator operates on follows below. [paper](https://arxiv.org/abs/2311.14957) [README](egs/vocoder/gan/tfr_enhanced_hifigan)

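For concreteness, the following sketch computes the kind of multi-resolution CQT magnitudes that a multi-scale CQT discriminator can operate on. The hop lengths are illustrative assumptions, not the settings from the paper:

```python
# Multi-resolution CQT magnitudes: a plausible input representation for a
# multi-scale CQT discriminator (illustrative settings, not the paper's).
import numpy as np
import librosa

sr = 22050
t = np.linspace(0, 2.0, 2 * sr, endpoint=False)
y = 0.5 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)  # 2 s test tone (A4)

scales = []
for hop in (128, 256, 512):  # three time resolutions (assumed values)
    # 84 bins at 12 bins/octave = 7 octaves above librosa's default fmin (~32.7 Hz)
    C = librosa.cqt(y, sr=sr, hop_length=hop, n_bins=84, bins_per_octave=12)
    scales.append(np.abs(C))  # magnitude CQT, shape (n_bins, n_frames)

print([s.shape for s in scales])
```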
### Evaluation

Amphion provides a comprehensive objective evaluation of the generated audio (a minimal sketch of two of these metrics follows the list below). The evaluation metrics include:

- **F0 Modeling**: F0 Pearson Coefficients, F0 Periodicity Root Mean Square Error, F0 Root Mean Square Error, Voiced/Unvoiced F1 Score, etc.
- **Energy Modeling**: Energy Root Mean Square Error, Energy Pearson Coefficients, etc.
- **Intelligibility**: Character/Word Error Rate, which can be calculated based on [Whisper](https://github.com/openai/whisper) and more.
- **Spectrogram Distortion**: Fréchet Audio Distance (FAD), Mel Cepstral Distortion (MCD), Multi-Resolution STFT Distance (MSTFT), Perceptual Evaluation of Speech Quality (PESQ), Short-Time Objective Intelligibility (STOI), etc.
- **Speaker Similarity**: Cosine similarity, which can be calculated based on [RawNet3](https://github.com/Jungjee/RawNet), [Resemblyzer](https://github.com/resemble-ai/Resemblyzer), [WeSpeaker](https://github.com/wenet-e2e/wespeaker), [WavLM](https://github.com/microsoft/unilm/tree/master/wavlm) and more.

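As a concrete illustration, here is a minimal numpy sketch of two of the metrics above, F0 RMSE/Pearson and speaker-embedding cosine similarity (illustrative formulas, not Amphion's evaluation code):

```python
# Minimal sketches of two metrics (illustrative only, not Amphion's code).
import numpy as np

def f0_rmse_and_pearson(f0_ref: np.ndarray, f0_gen: np.ndarray):
    """Compare two time-aligned F0 tracks on frames voiced in both."""
    voiced = (f0_ref > 0) & (f0_gen > 0)  # assume 0 marks unvoiced frames
    ref, gen = f0_ref[voiced], f0_gen[voiced]
    rmse = np.sqrt(np.mean((ref - gen) ** 2))
    pearson = np.corrcoef(ref, gen)[0, 1]
    return rmse, pearson

def cosine_similarity(emb_a: np.ndarray, emb_b: np.ndarray) -> float:
    """Speaker similarity between two speaker embeddings."""
    return float(np.dot(emb_a, emb_b) / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b)))

f0_a = np.array([0.0, 220.0, 225.0, 230.0, 0.0])
f0_b = np.array([0.0, 218.0, 226.0, 228.0, 0.0])
print(f0_rmse_and_pearson(f0_a, f0_b))
print(cosine_similarity(np.random.randn(256), np.random.randn(256)))
```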
### Datasets

- Amphion unifies the data preprocessing of open-source datasets including [AudioCaps](https://audiocaps.github.io/), [LibriTTS](https://www.openslr.org/60/), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), [M4Singer](https://github.com/M4Singer/M4Singer), [Opencpop](https://wenet.org.cn/opencpop/), [OpenSinger](https://github.com/Multi-Singer/Multi-Singer.github.io), [SVCC](http://vc-challenge.org/), [VCTK](https://datashare.ed.ac.uk/handle/10283/3443), and more. The supported dataset list can be seen [here](egs/datasets/README.md) (updating).
- Amphion (exclusively) supports the [**Emilia**](preprocessors/Emilia/README.md) dataset and its preprocessing pipeline **Emilia-Pipe** for in-the-wild speech data!

### Visualization

Amphion provides visualization tools that interactively illustrate the internal processing mechanisms of classic models. This provides an invaluable resource for educational purposes and for facilitating understandable research.

Currently, Amphion supports [SingVisio](egs/visualization/SingVisio/README.md), a visualization tool for the diffusion model used in singing voice conversion. [paper](https://arxiv.org/abs/2402.12660) [demo](https://openxlab.org.cn/apps/detail/Amphion/SingVisio) [video](https://drive.google.com/file/d/15097SGhQh-SwUNbdWDYNyWEP--YGLba5/view)

## 📀 Installation

Amphion can be installed through either the setup installer or a Docker image.

### Setup Installer

```bash
git clone https://github.com/open-mmlab/Amphion.git
cd Amphion

# Install Python Environment
conda create --name amphion python=3.9.15
conda activate amphion

# Install Python Packages Dependencies
sh env.sh
```

### Docker Image

1. Install [Docker](https://docs.docker.com/get-docker/), the [NVIDIA Driver](https://www.nvidia.com/download/index.aspx), the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html), and [CUDA](https://developer.nvidia.com/cuda-downloads).

2. Run the following commands:

```bash
git clone https://github.com/open-mmlab/Amphion.git
cd Amphion

docker pull realamphion/amphion
docker run --runtime=nvidia --gpus all -it -v .:/app realamphion/amphion
```

Mounting the dataset with the `-v` argument is necessary when using Docker. Please refer to [Mount dataset in Docker container](egs/datasets/docker.md) and the [Docker docs](https://docs.docker.com/engine/reference/commandline/container_run/#volume) for more details.

## 🐍 Usage in Python

We detail the instructions for different tasks in the following recipes:

- [Text to Speech (TTS)](egs/tts/README.md)
- [Singing Voice Conversion (SVC)](egs/svc/README.md)
- [Text to Audio (TTA)](egs/tta/README.md)
- [Vocoder](egs/vocoder/README.md)
- [Evaluation](egs/metrics/README.md)
- [Visualization](egs/visualization/README.md)

## 👨💻 Contributing

We appreciate all contributions to improve Amphion. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guidelines.

## 🙏 Acknowledgement

- [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2) and [jaywalnut310's VITS](https://github.com/jaywalnut310/vits) for model architecture code.
- [lifeiteng's VALL-E](https://github.com/lifeiteng/vall-e) for training pipeline and model architecture design.
- [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer) for semantic-distilled tokenizer design.
- [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), [ContentVec](https://github.com/auspicious3000/contentvec), and [RawNet3](https://github.com/Jungjee/RawNet) for pretrained models and inference code.
- [HiFi-GAN](https://github.com/jik876/hifi-gan) for the GAN-based vocoder's architecture design and training strategy.
- [Encodec](https://github.com/facebookresearch/encodec) for the well-organized GAN discriminator architecture and basic blocks.
- [Latent Diffusion](https://github.com/CompVis/latent-diffusion) for model architecture design.
- [TensorFlowTTS](https://github.com/TensorSpeech/TensorFlowTTS) for preparing the MFA tools.

## ©️ License

Amphion is under the [MIT License](LICENSE). It is free for both research and commercial use cases.

## 📚 Citations

```bibtex
@inproceedings{amphion,
    author={Zhang, Xueyao and Xue, Liumeng and Gu, Yicheng and Wang, Yuancheng and Li, Jiaqi and He, Haorui and Wang, Chaoren and Song, Ting and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zhang, Junan and Tang, Tze Ying and Zou, Lexiao and Wang, Mingxuan and Han, Jun and Chen, Kai and Li, Haizhou and Wu, Zhizheng},
    title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
    booktitle={{IEEE} Spoken Language Technology Workshop, {SLT} 2024},
    year={2024}
}
```
    	
models/__init__.py ADDED

(empty file)
    	
models/base/__init__.py ADDED

@@ -0,0 +1,7 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .new_trainer import BaseTrainer
from .new_inference import BaseInference
    	
models/base/base_dataset.py ADDED

@@ -0,0 +1,464 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json

import torch
import numpy as np
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence
import librosa

from utils.data_utils import *
from processors.acoustic_extractor import cal_normalized_mel
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation


class BaseOfflineDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use train or valid dataset
        """

        assert isinstance(dataset, str)

        # self.data_root = processed_data_dir
        self.cfg = cfg

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        """
        load spk2id and utt2spk from json file
            spk2id: {spk1: 0, spk2: 1, ...}
            utt2spk: {dataset_uid: spk1, ...}
        """
        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = dict()
            with open(utt2spk_path, "r") as f:
                for line in f.readlines():
                    utt, spk = line.strip().split("\t")
                    self.utt2spk[utt] = spk

        if cfg.preprocess.use_uv:
            self.utt2uv_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)
                self.utt2uv_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.uv_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_frame_pitch:
            self.utt2frame_pitch_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2frame_pitch_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.pitch_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_frame_energy:
            self.utt2frame_energy_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2frame_energy_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.energy_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_mel:
            self.utt2mel_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2mel_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.mel_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_linear:
            self.utt2linear_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2linear_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.linear_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_audio:
            self.utt2audio_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2audio_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.audio_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_label:
            self.utt2label_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2label_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.label_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_one_hot:
            self.utt2one_hot_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2one_hot_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.one_hot_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
            self.utt2seq = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.preprocess.use_text:
                    text = utt_info["Text"]
                    sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
                elif cfg.preprocess.use_phone:
                    # load phoneme sequence from phone file
                    phone_path = os.path.join(
                        processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
                    )
                    with open(phone_path, "r") as fin:
                        phones = fin.readlines()
                        assert len(phones) == 1
                        phones = phones[0].strip()
                    phones_seq = phones.split(" ")

                    phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
                    sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)

                self.utt2seq[utt] = sequence

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

… (diff view truncated here; base_dataset.py has 464 lines in total)
            +
             
     | 
| 198 | 
         
            +
                def __getitem__(self, index):
         
     | 
| 199 | 
         
            +
                    utt_info = self.metadata[index]
         
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
            +
                    dataset = utt_info["Dataset"]
         
     | 
| 202 | 
         
            +
                    uid = utt_info["Uid"]
         
     | 
| 203 | 
         
            +
                    utt = "{}_{}".format(dataset, uid)
         
     | 
| 204 | 
         
            +
             
     | 
| 205 | 
         
            +
                    single_feature = dict()
         
     | 
| 206 | 
         
            +
             
     | 
| 207 | 
         
            +
                    if self.cfg.preprocess.use_spkid:
         
     | 
| 208 | 
         
            +
                        single_feature["spk_id"] = np.array(
         
     | 
| 209 | 
         
            +
                            [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
         
     | 
| 210 | 
         
            +
                        )
         
     | 
| 211 | 
         
            +
             
     | 
| 212 | 
         
            +
                    if self.cfg.preprocess.use_mel:
         
     | 
| 213 | 
         
            +
                        mel = np.load(self.utt2mel_path[utt])
         
     | 
| 214 | 
         
            +
                        assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]
         
     | 
| 215 | 
         
            +
                        if self.cfg.preprocess.use_min_max_norm_mel:
         
     | 
| 216 | 
         
            +
                            # do mel norm
         
     | 
| 217 | 
         
            +
                            mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
                        if "target_len" not in single_feature.keys():
         
     | 
| 220 | 
         
            +
                            single_feature["target_len"] = mel.shape[1]
         
     | 
| 221 | 
         
            +
                        single_feature["mel"] = mel.T  # [T, n_mels]
         
     | 
| 222 | 
         
            +
             
     | 
| 223 | 
         
            +
                    if self.cfg.preprocess.use_linear:
         
     | 
| 224 | 
         
            +
                        linear = np.load(self.utt2linear_path[utt])
         
     | 
| 225 | 
         
            +
                        if "target_len" not in single_feature.keys():
         
     | 
| 226 | 
         
            +
                            single_feature["target_len"] = linear.shape[1]
         
     | 
| 227 | 
         
            +
                        single_feature["linear"] = linear.T  # [T, n_linear]
         
     | 
| 228 | 
         
            +
             
     | 
| 229 | 
         
            +
                    if self.cfg.preprocess.use_frame_pitch:
         
     | 
| 230 | 
         
            +
                        frame_pitch_path = self.utt2frame_pitch_path[utt]
         
     | 
| 231 | 
         
            +
                        frame_pitch = np.load(frame_pitch_path)
         
     | 
| 232 | 
         
            +
                        if "target_len" not in single_feature.keys():
         
     | 
| 233 | 
         
            +
                            single_feature["target_len"] = len(frame_pitch)
         
     | 
| 234 | 
         
            +
                        aligned_frame_pitch = align_length(
         
     | 
| 235 | 
         
            +
                            frame_pitch, single_feature["target_len"]
         
     | 
| 236 | 
         
            +
                        )
         
     | 
| 237 | 
         
            +
                        single_feature["frame_pitch"] = aligned_frame_pitch
         
     | 
| 238 | 
         
            +
             
     | 
| 239 | 
         
            +
                        if self.cfg.preprocess.use_uv:
         
     | 
| 240 | 
         
            +
                            frame_uv_path = self.utt2uv_path[utt]
         
     | 
| 241 | 
         
            +
                            frame_uv = np.load(frame_uv_path)
         
     | 
| 242 | 
         
            +
                            aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
         
     | 
| 243 | 
         
            +
                            aligned_frame_uv = [
         
     | 
| 244 | 
         
            +
                                0 if frame_uv else 1 for frame_uv in aligned_frame_uv
         
     | 
| 245 | 
         
            +
                            ]
         
     | 
| 246 | 
         
            +
                            aligned_frame_uv = np.array(aligned_frame_uv)
         
     | 
| 247 | 
         
            +
                            single_feature["frame_uv"] = aligned_frame_uv
         
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
                    if self.cfg.preprocess.use_frame_energy:
         
     | 
| 250 | 
         
            +
                        frame_energy_path = self.utt2frame_energy_path[utt]
         
     | 
| 251 | 
         
            +
                        frame_energy = np.load(frame_energy_path)
         
     | 
| 252 | 
         
            +
                        if "target_len" not in single_feature.keys():
         
     | 
| 253 | 
         
            +
                            single_feature["target_len"] = len(frame_energy)
         
     | 
| 254 | 
         
            +
                        aligned_frame_energy = align_length(
         
     | 
| 255 | 
         
            +
                            frame_energy, single_feature["target_len"]
         
     | 
| 256 | 
         
            +
                        )
         
     | 
| 257 | 
         
            +
                        single_feature["frame_energy"] = aligned_frame_energy
         
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
                    if self.cfg.preprocess.use_audio:
         
     | 
| 260 | 
         
            +
                        audio = np.load(self.utt2audio_path[utt])
         
     | 
| 261 | 
         
            +
                        single_feature["audio"] = audio
         
     | 
| 262 | 
         
            +
                        single_feature["audio_len"] = audio.shape[0]
         
     | 
| 263 | 
         
            +
             
     | 
| 264 | 
         
            +
                    if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
         
     | 
| 265 | 
         
            +
                        single_feature["phone_seq"] = np.array(self.utt2seq[utt])
         
     | 
| 266 | 
         
            +
                        single_feature["phone_len"] = len(self.utt2seq[utt])
         
     | 
| 267 | 
         
            +
             
     | 
| 268 | 
         
            +
                    return single_feature
         
     | 
| 269 | 
         
            +
             
     | 
| 270 | 
         
            +
                def __len__(self):
         
     | 
| 271 | 
         
            +
                    return len(self.metadata)
         
     | 
| 272 | 
         
            +
             
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
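The align_length helper used above comes from Amphion's utils and is not part of this diff; the calls rely on simple pad-or-truncate semantics. A minimal sketch of that behavior for illustration (the zero-padding choice here is an assumption, not Amphion's actual implementation):

    import numpy as np

    def align_length_sketch(feature: np.ndarray, target_len: int) -> np.ndarray:
        # Truncate if too long; otherwise zero-pad to target_len (assumed behavior).
        if len(feature) >= target_len:
            return feature[:target_len]
        padded = np.zeros(target_len, dtype=feature.dtype)
        padded[: len(feature)] = feature
        return padded

    pitch = np.random.rand(98)                              # e.g. a 98-frame pitch contour
    assert align_length_sketch(pitch, 100).shape == (100,)  # aligned to a 100-frame mel
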
class BaseOfflineCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [1, T]
        # target_len: [b]
        # spk_id: [b, 1]
        # mask: [b, T, 1]

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "phone_len":
                packed_batch_features["phone_len"] = torch.LongTensor(
                    [b["phone_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["phn_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "audio_len":
                packed_batch_features["audio_len"] = torch.LongTensor(
                    [b["audio_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
                ]
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
        return packed_batch_features

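To see what the collator above produces, here is a small self-contained demonstration of the same pad_sequence-based batching on a toy two-item batch (feature names and shapes are illustrative):

    import numpy as np
    import torch
    from torch.nn.utils.rnn import pad_sequence

    batch = [
        {"mel": np.random.rand(50, 80), "target_len": 50},
        {"mel": np.random.rand(70, 80), "target_len": 70},
    ]
    target_len = torch.LongTensor([b["target_len"] for b in batch])        # (2,)
    masks = [torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch]
    mask = pad_sequence(masks, batch_first=True, padding_value=0)          # (2, 70, 1)
    mel = pad_sequence(
        [torch.from_numpy(b["mel"]) for b in batch], batch_first=True
    )                                                                      # (2, 70, 80)

The mask marks real frames with ones, so downstream losses can ignore the zero-padded tail of the shorter item.
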
class BaseOnlineDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use train or valid dataset
        """
        assert isinstance(dataset, str)

        self.cfg = cfg
        self.sample_rate = cfg.preprocess.sample_rate
        self.hop_size = self.cfg.preprocess.hop_size

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        """
        load spk2id and utt2spk from json file
            spk2id: {spk1: 0, spk2: 1, ...}
            utt2spk: {dataset_uid: spk1, ...}
        """
        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = dict()
            with open(utt2spk_path, "r") as f:
                for line in f.readlines():
                    utt, spk = line.strip().split("\t")
                    self.utt2spk[utt] = spk

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

    def __getitem__(self, index):
        """
        single_feature:
            wav: (T)
            wav_len: int
            target_len: int
            mask: (n_frames, 1)
            spk_id: (1)
        """
        utt_item = self.metadata[index]

        wav_path = utt_item["Path"]
        wav, _ = librosa.load(wav_path, sr=self.sample_rate)
        # wav: (T)
        wav = torch.as_tensor(wav, dtype=torch.float32)
        wav_len = len(wav)
        # mask: (n_frames, 1)
        frame_len = wav_len // self.hop_size
        mask = torch.ones(frame_len, 1, dtype=torch.long)

        single_feature = {
            "wav": wav,
            "wav_len": wav_len,
            "target_len": frame_len,
            "mask": mask,
        }

        if self.cfg.preprocess.use_spkid:
            utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"])
            single_feature["spk_id"] = torch.tensor(
                [self.spk2id[self.utt2spk[utt]]], dtype=torch.int32
            )

        return single_feature

    def __len__(self):
        return len(self.metadata)

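The frame count in __getitem__ above is tied to the waveform length by the hop size; with illustrative values (the real ones come from cfg.preprocess):

    sample_rate, hop_size = 24000, 256   # assumed values for illustration
    wav_len = 3 * sample_rate            # a 3-second clip -> 72000 samples
    frame_len = wav_len // hop_size      # 281 frames, the length of the mask above
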
class BaseOnlineCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step
    (for on-the-fly feature extraction, where each dataset item contains only the waveform)."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        """
        BaseOnlineDataset.__getitem__:
            wav: (T,)
            wav_len: int
            target_len: int
            mask: (n_frames, 1)
            spk_id: (1)

        Returns:
            wav: (B, T), torch.float32
            wav_len: (B), torch.long
            target_len: (B), torch.long
            mask: (B, n_frames, 1), torch.long
            spk_id: (B, 1), torch.int32
        """
        packed_batch_features = dict()

        for key in batch[0].keys():
            if key in ["wav_len", "target_len"]:
                packed_batch_features[key] = torch.LongTensor([b[key] for b in batch])
            else:
                packed_batch_features[key] = pad_sequence(
                    [b[key] for b in batch], batch_first=True, padding_value=0
                )
        return packed_batch_features

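A typical way to wire the online dataset and collator into a loader (a usage sketch; the dataset name and cfg fields are assumptions following the code above):

    from torch.utils.data import DataLoader

    train_set = BaseOnlineDataset(cfg, "LJSpeech", is_valid=False)  # "LJSpeech" is a placeholder
    train_loader = DataLoader(
        train_set,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        collate_fn=BaseOnlineCollator(cfg),
    )
    for batch in train_loader:
        wav, mask = batch["wav"], batch["mask"]  # (B, T) and (B, n_frames, 1)
        break
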
class BaseTestDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, args):
        raise NotImplementedError

    def get_metadata(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.metadata)


class BaseTestCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        raise NotImplementedError

    def __call__(self, batch):
        raise NotImplementedError

models/base/base_inference.py ADDED
@@ -0,0 +1,220 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import re
import time
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.vocoders.vocoder_inference import synthesis
from utils.util import set_all_random_seed
from utils.util import load_config

def parse_vocoder(vocoder_dir):
    r"""Parse vocoder config"""
    vocoder_dir = os.path.abspath(vocoder_dir)
    ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
    ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
    ckpt_path = str(ckpt_list[0])
    vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
    vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
    return vocoder_cfg, ckpt_path

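parse_vocoder expects a directory containing step-named checkpoints alongside an args.json, and picks the one with the largest step number; a hypothetical call (the path is illustrative):

    vocoder_cfg, ckpt_path = parse_vocoder("ckpts/vocoder/bigvgan")
    # ckpt_path is the .pt file with the largest integer stem, e.g. .../1000000.pt
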
class BaseInference(object):
    def __init__(self, cfg, args):
        self.cfg = cfg
        self.args = args
        self.model_type = cfg.model_type
        self.avg_rtf = list()
        set_all_random_seed(10086)
        os.makedirs(args.output_dir, exist_ok=True)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
            torch.set_num_threads(10)  # use 10 CPU threads for inference

        # Load acoustic model
        self.model = self.create_model().to(self.device)
        state_dict = self.load_state_dict()
        self.load_model(state_dict)
        self.model.eval()

        # Load vocoder model if necessary
        if self.args.checkpoint_dir_vocoder is not None:
            self.get_vocoder_info()

    def create_model(self):
        raise NotImplementedError

    def load_state_dict(self):
        self.checkpoint_file = self.args.checkpoint_file
        if self.checkpoint_file is None:
            assert self.args.checkpoint_dir is not None
            checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
            checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
            self.checkpoint_file = os.path.join(
                self.args.checkpoint_dir, checkpoint_filename
            )

        self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]

        print("Restore acoustic model from {}".format(self.checkpoint_file))
        raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
        self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]

        return raw_state_dict

    def load_model(self, model):
        raise NotImplementedError

    def get_vocoder_info(self):
        self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
        self.vocoder_cfg = os.path.join(
            os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
        )
        self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
        self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
        self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]

    def build_test_utt_data(self):
        raise NotImplementedError

    def build_testdata_loader(self, args, target_speaker=None):
        datasets, collate = self.build_test_dataset()
        self.test_dataset = datasets(self.cfg, args, target_speaker)
        self.test_collate = collate(self.cfg)
        self.test_batch_size = min(
            self.cfg.train.batch_size, len(self.test_dataset.metadata)
        )
        test_loader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=self.args.num_workers,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_loader

    def inference_each_batch(self, batch_data):
        raise NotImplementedError

    def inference_for_batches(self, args, target_speaker=None):
        ###### Construct test_batch ######
        loader = self.build_testdata_loader(args, target_speaker)

        n_batch = len(loader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
                # Put the data to device
                for k, v in batch_data.items():
                    batch_data[k] = v.to(self.device)

                y_pred, stats = self.inference_each_batch(batch_data)

                pred_res += y_pred

        return pred_res

    def inference(self, feature):
        raise NotImplementedError

    def synthesis_by_vocoder(self, pred):
        audios_pred = synthesis(
            self.vocoder_cfg,
            self.checkpoint_dir_vocoder,
            len(pred),
            pred,
        )
        return audios_pred

    def __call__(self, utt):
        feature = self.build_test_utt_data(utt)
        start_time = time.time()
        with torch.no_grad():
            outputs = self.inference(feature)[0]
        time_used = time.time() - start_time
        rtf = time_used / (
            outputs.shape[1]
            * self.cfg.preprocess.hop_size
            / self.cfg.preprocess.sample_rate
        )
        print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
        self.avg_rtf.append(rtf)
        audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
        return audios

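BaseInference leaves create_model, load_model, and the test-data hooks abstract; a minimal subclass could look like this sketch (MyAcousticModel and the "model" key in the checkpoint are assumptions for illustration):

    class MyInference(BaseInference):
        def create_model(self):
            # Any nn.Module constructed from self.cfg; MyAcousticModel is hypothetical.
            return MyAcousticModel(self.cfg)

        def load_model(self, state_dict):
            # Assumes the checkpoint stores the weights under a "model" key.
            self.model.load_state_dict(state_dict["model"])
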
def base_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", default="config.json", help="json files for configurations."
    )
    parser.add_argument("--use_ddp_inference", default=False)
    parser.add_argument("--n_workers", default=1, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    parser.add_argument(
        "--num_workers",
        default=1,
        type=int,
        help="Worker number for inference dataloader",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        help="Checkpoint dir including model file and configuration",
    )
    parser.add_argument(
        "--checkpoint_file", help="checkpoint file", type=str, default=None
    )
    parser.add_argument(
        "--test_list", help="test utterance list for testing", type=str, default=None
    )
    parser.add_argument(
        "--checkpoint_dir_vocoder",
        help="Vocoder's checkpoint dir including model file and configuration",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    return parser


if __name__ == "__main__":
    parser = base_parser()
    args = parser.parse_args()
    cfg = load_config(args.config)

    # Build inference
    inference = BaseInference(cfg, args)
    inference()
models/base/base_sampler.py ADDED
@@ -0,0 +1,157 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+
+from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data.sampler import (
+    BatchSampler,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)
+
+
+class ScheduledSampler(Sampler):
+    """A sampler that samples data from a given concat-dataset.
+
+    Args:
+        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
+        batch_size (int): batch size
+        holistic_shuffle (bool): whether to shuffle the whole dataset or not
+        logger (logging.Logger): logger to print warning message
+
+    Usage:
+        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True
+        (illustrative; real sub-datasets must also implement get_dataset_name()):
+        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), 3, False))
+        [3, 4, 5, 0, 1, 2, 6, 7, 8]
+    """
+
+    def __init__(
+        self,
+        concat_dataset,
+        batch_size,
+        holistic_shuffle,
+        logger=None,
+        loader_type="train",
+    ):
+        if not isinstance(concat_dataset, ConcatDataset):
+            raise ValueError(
+                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
+                    type(concat_dataset)
+                )
+            )
+        if not isinstance(batch_size, int):
+            raise ValueError(
+                "batch_size must be an integer, but got {}".format(type(batch_size))
+            )
+        if not isinstance(holistic_shuffle, bool):
+            raise ValueError(
+                "holistic_shuffle must be a boolean, but got {}".format(
+                    type(holistic_shuffle)
+                )
+            )
+
+        self.concat_dataset = concat_dataset
+        self.batch_size = batch_size
+        self.holistic_shuffle = holistic_shuffle
+
+        # collect sub-datasets that are smaller than one batch
+        affected_dataset_name = []
+        affected_dataset_len = []
+        for dataset in concat_dataset.datasets:
+            dataset_len = len(dataset)
+            dataset_name = dataset.get_dataset_name()
+            if dataset_len < batch_size:
+                affected_dataset_name.append(dataset_name)
+                affected_dataset_len.append(dataset_len)
+
+        self.type = loader_type
+        for dataset_name, dataset_len in zip(
+            affected_dataset_name, affected_dataset_len
+        ):
+            # guard against logger=None so the warning path cannot crash
+            if loader_type != "valid" and logger is not None:
+                logger.warning(
+                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
+                        loader_type, dataset_name, dataset_len, batch_size
+                    )
+                )
+
+    def __len__(self):
+        # the number of batches with drop last
+        num_of_batches = sum(
+            [
+                math.floor(len(dataset) / self.batch_size)
+                for dataset in self.concat_dataset.datasets
+            ]
+        )
+        # if samples are not enough for one batch, we don't drop last
+        if self.type == "valid" and num_of_batches < 1:
+            return len(self.concat_dataset)
+        return num_of_batches * self.batch_size
+
+    def __iter__(self):
+        iters = []
+        for dataset in self.concat_dataset.datasets:
+            iters.append(
+                SequentialSampler(dataset).__iter__()
+                if not self.holistic_shuffle
+                else RandomSampler(dataset).__iter__()
+            )
+        # offsets of each sub-dataset within the concat dataset, e.g. [0, 200, 400]
+        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
+        output_batches = []
+        for dataset_idx in range(len(self.concat_dataset.datasets)):
+            cur_batch = []
+            for idx in iters[dataset_idx]:
+                cur_batch.append(idx + init_indices[dataset_idx])
+                if len(cur_batch) == self.batch_size:
+                    output_batches.append(cur_batch)
+                    cur_batch = []
+            # if loader_type is valid, we keep the last incomplete batch;
+            # in training the tail is dropped (forced drop_last)
+            if self.type == "valid" and len(cur_batch) > 0:
+                output_batches.append(cur_batch)
+
+        # shuffle the batch order, then flatten into a single index stream
+        random.shuffle(output_batches)
+        output_indices = [item for sublist in output_batches for item in sublist]
+        return iter(output_indices)
+
+
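A self-contained sketch (editor's example, not part of the diff) of how ScheduledSampler batches a ConcatDataset; ToyDataset is a hypothetical stand-in providing the get_dataset_name() hook the sampler calls:

    import logging
    from torch.utils.data import ConcatDataset, Dataset

    class ToyDataset(Dataset):
        def __init__(self, name, items):
            self.name, self.items = name, items
        def __len__(self):
            return len(self.items)
        def __getitem__(self, idx):
            return self.items[idx]
        def get_dataset_name(self):
            return self.name

    concat = ConcatDataset([ToyDataset("a", [0, 1, 2]), ToyDataset("b", [3, 4, 5])])
    sampler = ScheduledSampler(concat, batch_size=3, holistic_shuffle=False,
                               logger=logging.getLogger(__name__), loader_type="train")
    # Batches never cross sub-dataset boundaries; only the batch order is shuffled.
    print(list(sampler))  # e.g. [3, 4, 5, 0, 1, 2]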
+def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
+    sampler = ScheduledSampler(
+        concat_dataset,
+        cfg.train.batch_size,
+        cfg.train.sampler.holistic_shuffle,
+        logger,
+        loader_type,
+    )
+    batch_sampler = BatchSampler(
+        sampler,
+        cfg.train.batch_size,
+        # validation never drops the tail batch
+        cfg.train.sampler.drop_last if loader_type != "valid" else False,
+    )
+    return sampler, batch_sampler
+
+
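A short sketch (editor's example) of wiring build_samplers into a DataLoader; the SimpleNamespace stands in for the real cfg object and reuses `concat` and the logger from the sketch above:

    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    cfg = SimpleNamespace(
        train=SimpleNamespace(
            batch_size=3,
            sampler=SimpleNamespace(holistic_shuffle=False, drop_last=True),
        )
    )
    _, batch_sampler = build_samplers(concat, cfg, logging.getLogger(__name__), "train")
    loader = DataLoader(concat, batch_sampler=batch_sampler)  # yields whole batches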
+class VariableSampler(BatchSampler):
+    def __init__(self, sampler, drop_last: bool, use_random_sampler=False):
+        # `sampler` is a precomputed list of batches (each item is a list of indices)
+        self.data_list = sampler
+        if use_random_sampler:
+            self.sampler = RandomSampler(sampler)
+        else:
+            self.sampler = SequentialSampler(sampler)
+
+        super().__init__(self.sampler, 1, drop_last)
+
+    def __iter__(self):
+        # replay the precomputed batches as-is
+        for batch_ids in self.data_list:
+            yield batch_ids
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
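VariableSampler simply replays precomputed, possibly variable-sized batches; a tiny sketch (editor's example, hypothetical indices):

    batches = [[0, 1], [2, 3, 4], [5]]      # one list of indices per batch
    vs = VariableSampler(batches, drop_last=False)
    print(list(vs))                         # [[0, 1], [2, 3, 4], [5]]
    print(len(vs))                          # 3 batches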
    	
models/base/base_trainer.py
ADDED
@@ -0,0 +1,348 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import collections
+import json
+import os
+import sys
+import time
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data import ConcatDataset, DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from models.base.base_sampler import BatchSampler
+from utils.util import (
+    Logger,
+    remove_older_ckpt,
+    save_config,
+    set_all_random_seed,
+    ValueWindow,
+)
+
+
+class BaseTrainer(object):
+    def __init__(self, args, cfg):
+        self.args = args
+        self.log_dir = args.log_dir
+        self.cfg = cfg
+
+        self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
+        os.makedirs(self.checkpoint_dir, exist_ok=True)
+        if not cfg.train.ddp or args.local_rank == 0:
+            self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
+            self.logger = self.build_logger()
+        self.time_window = ValueWindow(50)
+
+        self.step = 0
+        self.epoch = -1
+        self.max_epochs = self.cfg.train.epochs
+        self.max_steps = self.cfg.train.max_steps
+
+        # set random seed & init distributed training
+        set_all_random_seed(self.cfg.train.random_seed)
+        if cfg.train.ddp:
+            dist.init_process_group(backend="nccl")
+
+        if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
+            self.singers = self.build_singers_lut()
+
+        # setup data_loader
+        self.data_loader = self.build_data_loader()
+
+        # setup model & enable distributed training
+        self.model = self.build_model()
+        print(self.model)
+
+        if isinstance(self.model, dict):
+            for key, value in self.model.items():
+                value.cuda(self.args.local_rank)
+                if key == "PQMF":
+                    continue
+                if cfg.train.ddp:
+                    self.model[key] = DistributedDataParallel(
+                        value, device_ids=[self.args.local_rank]
+                    )
+        else:
+            self.model.cuda(self.args.local_rank)
+            if cfg.train.ddp:
+                self.model = DistributedDataParallel(
+                    self.model, device_ids=[self.args.local_rank]
+                )
+
+        # create criterion
+        self.criterion = self.build_criterion()
+        if isinstance(self.criterion, dict):
+            for key, value in self.criterion.items():
+                self.criterion[key].cuda(args.local_rank)
+        else:
+            self.criterion.cuda(self.args.local_rank)
+
+        # optimizer
+        self.optimizer = self.build_optimizer()
+        self.scheduler = self.build_scheduler()
+
+        # save config file
+        self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
+
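A sketch (editor's example) of the minimal args object the constructor expects; the field names are taken from the attribute accesses in the class, while the values are assumptions:

    from argparse import Namespace

    args = Namespace(
        log_dir="ckpts/my_exp",   # hypothetical experiment directory
        local_rank=0,             # single-GPU default
        log_level="info",         # consumed by build_logger below
        num_workers=4,            # consumed by build_data_loader below
        stdout_interval=10,       # consumed by train_epoch below
    )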
+    def build_logger(self):
+        log_file = os.path.join(self.checkpoint_dir, "train.log")
+        logger = Logger(log_file, level=self.args.log_level).logger
+
+        return logger
+
+    def build_dataset(self):
+        raise NotImplementedError
+
+    def build_data_loader(self):
+        Dataset, Collator = self.build_dataset()
+        # build dataset instance for each dataset and combine them by ConcatDataset
+        datasets_list = []
+        for dataset in self.cfg.dataset:
+            subdataset = Dataset(self.cfg, dataset, is_valid=False)
+            datasets_list.append(subdataset)
+        train_dataset = ConcatDataset(datasets_list)
+
+        train_collate = Collator(self.cfg)
+        # TODO: multi-GPU training
+        if self.cfg.train.ddp:
+            raise NotImplementedError("DDP is not supported yet.")
+
+        # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
+        batch_sampler = BatchSampler(
+            cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
+        )
+
+        # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
+        train_loader = DataLoader(
+            train_dataset,
+            collate_fn=train_collate,
+            num_workers=self.args.num_workers,
+            batch_sampler=batch_sampler,
+            pin_memory=False,
+        )
+        if not self.cfg.train.ddp or self.args.local_rank == 0:
+            datasets_list = []
+            for dataset in self.cfg.dataset:
+                subdataset = Dataset(self.cfg, dataset, is_valid=True)
+                datasets_list.append(subdataset)
+            valid_dataset = ConcatDataset(datasets_list)
+            valid_collate = Collator(self.cfg)
+            batch_sampler = BatchSampler(
+                cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
+            )
+            valid_loader = DataLoader(
+                valid_dataset,
+                collate_fn=valid_collate,
+                num_workers=1,
+                batch_sampler=batch_sampler,
+            )
+        else:
+            raise NotImplementedError("DDP is not supported yet.")
+            # valid_loader = None
+        data_loader = {"train": train_loader, "valid": valid_loader}
+        return data_loader
+
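What a subclass is expected to feed build_data_loader (editor's sketch; MyDataset/MyCollator are hypothetical names). Note that build_dataset returns the classes themselves, which the trainer then instantiates as Dataset(cfg, dataset_name, is_valid=...) and Collator(cfg), as seen above:

    class MyDataset(torch.utils.data.Dataset):
        def __init__(self, cfg, dataset_name, is_valid=False):
            self.metadata = []        # load per-dataset metadata here
        def __len__(self):
            return len(self.metadata)
        def __getitem__(self, idx):
            return self.metadata[idx]

    class MyCollator:
        def __init__(self, cfg):
            self.cfg = cfg
        def __call__(self, batch):
            return batch              # pad/stack tensors in a real collator

    class MyTrainer(BaseTrainer):
        def build_dataset(self):
            return MyDataset, MyCollator   # classes, not instances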
+    def build_singers_lut(self):
+        # combine singers
+        if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
+            singers = collections.OrderedDict()
+        else:
+            with open(
+                os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
+            ) as singer_file:
+                singers = json.load(singer_file)
+        singer_count = len(singers)
+        for dataset in self.cfg.dataset:
+            singer_lut_path = os.path.join(
+                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
+            )
+            # use a distinct handle name so the path variable is not shadowed
+            with open(singer_lut_path, "r") as singer_lut_file:
+                singer_lut = json.load(singer_lut_file)
+            for singer in singer_lut.keys():
+                if singer not in singers:
+                    singers[singer] = singer_count
+                    singer_count += 1
+        with open(
+            os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
+        ) as singer_file:
+            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
+        print(
+            "singers have been dumped to {}".format(
+                os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
+            )
+        )
+        return singers
+
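The resulting spk2id file is a flat name-to-index JSON lookup table (editor's illustration; the singer names are hypothetical):

    {
        "Opencpop_female1": 0,
        "M4Singer_Alto-1": 1
    }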
+    def build_model(self):
+        raise NotImplementedError()
+
+    def build_optimizer(self):
+        raise NotImplementedError
+
+    def build_scheduler(self):
+        raise NotImplementedError()
+
+    def build_criterion(self):
+        raise NotImplementedError
+
+    def get_state_dict(self):
+        raise NotImplementedError
+
+    def save_config_file(self):
+        save_config(self.config_save_path, self.cfg)
+
+    # TODO, save without module.
+    def save_checkpoint(self, state_dict, saved_model_path):
+        torch.save(state_dict, saved_model_path)
+
+    def load_checkpoint(self):
+        checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
+        assert os.path.exists(checkpoint_path)
+        # the last line of the index file names the newest checkpoint
+        with open(checkpoint_path) as checkpoint_file:
+            checkpoint_filename = checkpoint_file.readlines()[-1].strip()
+        model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
+        assert os.path.exists(model_path)
+        if not self.cfg.train.ddp or self.args.local_rank == 0:
+            self.logger.info(f"Re(store) from {model_path}")
+        checkpoint = torch.load(model_path, map_location="cpu")
+        return checkpoint
+
+    def load_model(self, checkpoint):
+        raise NotImplementedError
+
+    def restore(self):
+        checkpoint = self.load_checkpoint()
+        self.load_model(checkpoint)
+
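load_checkpoint relies on a plain-text "checkpoint" index file in checkpoint_dir whose last line names the newest model file, matching the step-XXXXXXX_loss-X.XXXX.pt pattern used by train_epoch below (editor's illustration, hypothetical values):

    step-0005000_loss-1.2345.pt
    step-0010000_loss-0.9876.pt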
+    def train_step(self, data):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    @torch.no_grad()
+    def eval_step(self):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    def write_summary(self, losses, stats):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    def write_valid_summary(self, losses, stats):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    def echo_log(self, losses, mode="Training"):
+        message = [
+            "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
+                mode, self.epoch + 1, self.step, self.time_window.average
+            )
+        ]
+
+        for key in sorted(losses.keys()):
+            if isinstance(losses[key], dict):
+                for k, v in losses[key].items():
+                    message.append(
+                        str(k).split("/")[-1] + "=" + str(round(float(v), 5))
+                    )
+            else:
+                message.append(
+                    str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
+                )
+        self.logger.info(", ".join(message))
+
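echo_log flattens nested loss dicts into a single log line (editor's illustration; the loss names and numbers are hypothetical):

    Training - Epoch 1 Step 200: [0.532 s/step], gen_loss=0.91342, mel=0.20011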
+    def eval_epoch(self):
+        self.logger.info("Validation...")
+        valid_losses = {}
+        for i, batch_data in enumerate(self.data_loader["valid"]):
+            for k, v in batch_data.items():
+                if isinstance(v, torch.Tensor):
+                    batch_data[k] = v.cuda()
+            valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
+            for key in valid_loss:
+                if key not in valid_losses:
+                    valid_losses[key] = 0
+                valid_losses[key] += valid_loss[key]
+
+        # Add mel and audio to the Tensorboard
+        # Average loss over the number of validation batches
+        for key in valid_losses:
+            valid_losses[key] /= i + 1
+        self.echo_log(valid_losses, "Valid")
+        return valid_losses, valid_stats
+
+    def train_epoch(self):
+        for i, batch_data in enumerate(self.data_loader["train"]):
+            start_time = time.time()
+            # Put the data to cuda device
+            for k, v in batch_data.items():
+                if isinstance(v, torch.Tensor):
+                    batch_data[k] = v.cuda(self.args.local_rank)
+
+            # Training step
+            train_losses, train_stats, total_loss = self.train_step(batch_data)
+            self.time_window.append(time.time() - start_time)
+
+            if self.args.local_rank == 0 or not self.cfg.train.ddp:
+                if self.step % self.args.stdout_interval == 0:
+                    self.echo_log(train_losses, "Training")
+
+                if self.step % self.cfg.train.save_summary_steps == 0:
+                    self.logger.info(f"Save summary as step {self.step}")
+                    self.write_summary(train_losses, train_stats)
+
+                if (
+                    self.step % self.cfg.train.save_checkpoints_steps == 0
+                    and self.step != 0
+                ):
+                    saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
+                        self.step, total_loss
+                    )
+                    saved_model_path = os.path.join(
+                        self.checkpoint_dir, saved_model_name
+                    )
+                    saved_state_dict = self.get_state_dict()
+                    self.save_checkpoint(saved_state_dict, saved_model_path)
+                    self.save_config_file()
+                    # keep max n models
+                    remove_older_ckpt(
+                        saved_model_name,
+                        self.checkpoint_dir,
+                        max_to_keep=self.cfg.train.keep_checkpoint_max,
+                    )
+
+                if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
+                    if isinstance(self.model, dict):
+                        for key in self.model.keys():
+                            self.model[key].eval()
+                    else:
+                        self.model.eval()
+                    # Evaluate one epoch and get average loss
+                    valid_losses, valid_stats = self.eval_epoch()
+                    if isinstance(self.model, dict):
+                        for key in self.model.keys():
+                            self.model[key].train()
+                    else:
+                        self.model.train()
+                    # Write validation losses to summary.
+                    self.write_valid_summary(valid_losses, valid_stats)
+            self.step += 1
+
+    def train(self):
+        for epoch in range(max(0, self.epoch), self.max_epochs):
+            self.train_epoch()
+            self.epoch += 1
+            if self.step > self.max_steps:
+                self.logger.info("Training finished!")
+                break
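The intended driver pattern, mirroring the __main__ block of base_inference.py earlier in this diff (editor's sketch; the parser construction, config path, and MyTrainer subclass are assumptions):

    args = parser.parse_args()        # must carry log_dir, local_rank, num_workers, ...
    cfg = load_config(args.config)
    trainer = MyTrainer(args, cfg)    # MyTrainer subclasses BaseTrainer, as sketched above
    trainer.train()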
    	
models/base/new_dataset.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from abc import abstractmethod
from pathlib import Path

import json5
import torch
import yaml


# TODO: for training and validating
class BaseDataset(torch.utils.data.Dataset):
    r"""Base dataset for training and validating."""

    def __init__(self, args, cfg, is_valid=False):
        pass


class BaseTestDataset(torch.utils.data.Dataset):
    r"""Test dataset for inference."""

    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        assert infer_type in ["from_dataset", "from_file"]

        self.args = args
        self.cfg = cfg
        self.infer_type = infer_type

    @abstractmethod
    def __getitem__(self, index):
        pass

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        path = Path(self.args.source)
        if path.suffix == ".json" or path.suffix == ".jsonc":
            metadata = json5.load(open(self.args.source, "r"))
        elif path.suffix == ".yaml" or path.suffix == ".yml":
            metadata = yaml.full_load(open(self.args.source, "r"))
        else:
            raise ValueError(f"Unsupported file type: {path.suffix}")

        return metadata
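Worth noting about the file above: BaseTestDataset.__len__ dereferences self.metadata, but nothing in this base class ever assigns it, so a concrete subclass is expected to populate it, typically from the args.source file that get_metadata() parses (json/jsonc via json5, yaml/yml via yaml.full_load). A minimal subclass sketch under that assumption (the subclass name and source file are hypothetical, not part of this diff):

    # Hypothetical subclass showing the expected wiring; assumes the
    # BaseTestDataset defined above is importable from models.base.new_dataset.
    from types import SimpleNamespace

    from models.base.new_dataset import BaseTestDataset

    class MyTestDataset(BaseTestDataset):
        def __init__(self, args=None, cfg=None, infer_type="from_file"):
            super().__init__(args, cfg, infer_type)
            # get_metadata() reads args.source; each entry is assumed to be a
            # dict carrying at least a "Uid" key, as the inference code expects.
            self.metadata = self.get_metadata()

        def __getitem__(self, index):
            return self.metadata[index]

    args = SimpleNamespace(source="test_cases.json")  # hypothetical metadata file
    dataset = MyTestDataset(args=args)
    print(len(dataset))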
    	
models/base/new_inference.py
ADDED
@@ -0,0 +1,253 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import re
import time
from abc import abstractmethod
from pathlib import Path

import accelerate
import json5
import numpy as np
import torch
from accelerate.logging import get_logger
from torch.utils.data import DataLoader

from models.vocoders.vocoder_inference import synthesis
from utils.io import save_audio
from utils.util import load_config
from utils.audio_slicer import is_silence

EPS = 1.0e-12


class BaseInference(object):
    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        super().__init__()

        start = time.monotonic_ns()
        self.args = args
        self.cfg = cfg

        assert infer_type in ["from_dataset", "from_file"]
        self.infer_type = infer_type

        # init with accelerate
        self.accelerator = accelerate.Accelerator()
        self.accelerator.wait_for_everyone()

        # Use accelerate logger for distributed inference
        with self.accelerator.main_process_first():
            self.logger = get_logger("inference", log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")
        self.logger.debug(f"Using {args.log_level.upper()} logging level.")

        self.acoustics_dir = args.acoustics_dir
        self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
        self.vocoder_dir = args.vocoder_dir
        self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
        # should be in svc inferencer
        # self.target_singer = args.target_singer
        # self.logger.info(f"Target singers: {args.target_singer}")
        # self.trans_key = args.trans_key
        # self.logger.info(f"Trans key: {args.trans_key}")

        os.makedirs(args.output_dir, exist_ok=True)

        # set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # setup data_loader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.test_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # setup model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            # self.logger.debug(self.model)
            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")

        # init with accelerate
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self.accelerator = accelerate.Accelerator()
        self.model = self.accelerator.prepare(self.model)
        end = time.monotonic_ns()
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")

        with self.accelerator.main_process_first():
            self.logger.info("Loading checkpoint...")
            start = time.monotonic_ns()
            # TODO: for now, assume only the latest checkpoint is used
            self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
            end = time.monotonic_ns()
            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")

        self.model.eval()
        self.accelerator.wait_for_everyone()

    ### Abstract methods ###
    @abstractmethod
    def _build_test_dataset(self):
        pass

    @abstractmethod
    def _build_model(self):
        pass

    @abstractmethod
    @torch.inference_mode()
    def _inference_each_batch(self, batch_data):
        pass

    ### Abstract methods end ###

    @torch.inference_mode()
    def inference(self):
        for i, batch in enumerate(self.test_dataloader):
            y_pred = self._inference_each_batch(batch).cpu()

            # Judge whether min-max mel normalization is used
            if self.cfg.preprocess.use_min_max_norm_mel:
                mel_min, mel_max = self.test_dataset.target_mel_extrema
                y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min

            y_ls = y_pred.chunk(self.test_batch_size)
            tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
            j = 0
            for it, l in zip(y_ls, tgt_ls):
                l = l.item()
                it = it.squeeze(0)[:l]
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
                j += 1

        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)

        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
                ).numpy(force=True)
                for i in self.test_dataset.metadata
            ],
        )

        output_audio_files = []
        for it, wav in zip(self.test_dataset.metadata, res):
            uid = it["Uid"]
            file = os.path.join(self.args.output_dir, f"{uid}.wav")
            output_audio_files.append(file)

            wav = wav.numpy(force=True)
            save_audio(
                file,
                wav,
                self.cfg.preprocess.sample_rate,
                add_silence=False,
                turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
            )
            os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))

        return sorted(output_audio_files)

    # TODO: LEGACY CODE
    def _build_dataloader(self):
        datasets, collate = self._build_test_dataset()
        self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
        self.test_collate = collate(self.cfg)
        self.test_batch_size = min(
            self.cfg.train.batch_size, len(self.test_dataset.metadata)
        )
        test_dataloader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=1,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_dataloader

    def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            ls = []
            for i in Path(checkpoint_dir).iterdir():
                if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
                    ls.append(i)
            ls.sort(
                key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
            )
            checkpoint_path = ls[0]
        else:
            checkpoint_path = Path(checkpoint_path)
        self.accelerator.load_state(str(checkpoint_path))
        # set epoch and step
        self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
        self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
        return str(checkpoint_path)

    @staticmethod
    def _set_random_seed(seed):
        r"""Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse vocoder config"""
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
        ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
        ckpt_path = str(ckpt_list[0])
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    @staticmethod
    def __count_parameters(model):
        return sum(p.numel() for p in model.parameters())

    def __dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        json5.dump(
            self.cfg,
            open(path, "w"),
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
            quote_keys=True,
        )
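One load-bearing detail in the inference class above: __load_model discovers checkpoints purely by name, so checkpoint directories must match the epoch-\d+_step-\d+_loss-[\d.]+ pattern for loading to work, and epoch/step are then recovered by string-splitting the stem. A small standalone check of that parsing (the stem value below is invented for illustration):

    # Mirrors the stem parsing in BaseInference.__load_model above.
    stem = "epoch-0012_step-034567_loss-0.1234"  # hypothetical checkpoint name
    epoch = int(stem.split("_")[-3].split("-")[-1])  # "epoch-0012" -> 12
    step = int(stem.split("_")[-2].split("-")[-1])   # "step-034567" -> 34567
    print(epoch, step)  # 12 34567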
    	
models/base/new_trainer.py
ADDED
@@ -0,0 +1,727 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import random
import shutil
import time
from abc import abstractmethod
from pathlib import Path

import accelerate
import json5
import numpy as np
import torch
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

from models.base.base_sampler import build_samplers
from optimizer.optimizers import NoamLR


class BaseTrainer(object):
    r"""The base trainer for all tasks. Any trainer should inherit from this class."""

    def __init__(self, args=None, cfg=None):
        super().__init__()

        self.args = args
        self.cfg = cfg

        cfg.exp_name = args.exp_name

        # init with accelerate
        self._init_accelerator()
        self.accelerator.wait_for_everyone()

        # Use accelerate logger for distributed training
        with self.accelerator.main_process_first():
            self.logger = get_logger(args.exp_name, log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New training process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")
        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
        self.logger.info(f"Experiment name: {args.exp_name}")
        self.logger.info(f"Experiment directory: {self.exp_dir}")
        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
        if self.accelerator.is_main_process:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # init counts
        self.batch_count: int = 0
        self.step: int = 0
        self.epoch: int = 0
        self.max_epoch = (
            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
        )
        self.logger.info(
            "Max epoch: {}".format(
                self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
            )
        )

        # Check values
        if self.accelerator.is_main_process:
            self.__check_basic_configs()
            # Set runtime configs
            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
            self.checkpoints_path = [
                [] for _ in range(len(self.save_checkpoint_stride))
            ]
            self.keep_last = [
                i if i > 0 else float("inf") for i in self.cfg.train.keep_last
            ]
            self.run_eval = self.cfg.train.run_eval

        # set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # setup data_loader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # setup model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            self.logger.debug(self.model)
            self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
            self.logger.info(
                f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
            )
        # optimizer & scheduler
        with self.accelerator.main_process_first():
            self.logger.info("Building optimizer and scheduler...")
            start = time.monotonic_ns()
            self.optimizer = self._build_optimizer()
            self.scheduler = self._build_scheduler()
            end = time.monotonic_ns()
            self.logger.info(
                f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
            )

        # accelerate prepare
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self._accelerator_prepare()
        end = time.monotonic_ns()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")

        # create criterion
        with self.accelerator.main_process_first():
            self.logger.info("Building criterion...")
            start = time.monotonic_ns()
            self.criterion = self._build_criterion()
            end = time.monotonic_ns()
            self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")

        # Resume or Finetune
        with self.accelerator.main_process_first():
            if args.resume:
                if args.resume_from_ckpt_path == "":
                    ## Automatically resume according to the current experiment name
                    self.logger.info(
                        "Automatically resuming from latest checkpoint in {}...".format(
                            self.checkpoint_dir
                        )
                    )
                    start = time.monotonic_ns()
                    ckpt_path = self._load_model(
                        checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
                                )
         
     | 
| 154 | 
         
            +
                                end = time.monotonic_ns()
         
     | 
| 155 | 
         
            +
                                self.logger.info(
         
     | 
| 156 | 
         
            +
                                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
         
     | 
| 157 | 
         
            +
                                )
         
     | 
| 158 | 
         
            +
                                self.checkpoints_path = json.load(
         
     | 
| 159 | 
         
            +
                                    open(os.path.join(ckpt_path, "ckpts.json"), "r")
         
     | 
| 160 | 
         
            +
                                )
         
     | 
| 161 | 
         
            +
                            else:
         
     | 
| 162 | 
         
            +
                                ## Resume from the given checkpoint path
         
     | 
| 163 | 
         
            +
                                if not os.path.exists(args.resume_from_ckpt_path):
         
     | 
| 164 | 
         
            +
                                    raise ValueError(
         
     | 
| 165 | 
         
            +
                                        "[Error] The resumed checkpoint path {} don't exist.".format(
         
     | 
| 166 | 
         
            +
                                            args.resume_from_ckpt_path
         
     | 
| 167 | 
         
            +
                                        )
         
     | 
| 168 | 
         
            +
                                    )
         
     | 
| 169 | 
         
            +
                                self.logger.info(
         
     | 
| 170 | 
         
            +
                                    "Resuming from {}...".format(args.resume_from_ckpt_path)
         
     | 
| 171 | 
         
            +
                                )
         
     | 
| 172 | 
         
            +
                                start = time.monotonic_ns()
         
     | 
| 173 | 
         
            +
                                ckpt_path = self._load_model(
         
     | 
| 174 | 
         
            +
                                    checkpoint_path=args.resume_from_ckpt_path,
         
     | 
| 175 | 
         
            +
                                    resume_type=args.resume_type,
         
     | 
| 176 | 
         
            +
                                )
         
     | 
| 177 | 
         
            +
                                end = time.monotonic_ns()
         
     | 
| 178 | 
         
            +
                                self.logger.info(
         
     | 
| 179 | 
         
            +
                                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
         
     | 
| 180 | 
         
            +
                                )
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    # save config file path
         
     | 
| 183 | 
         
            +
                    self.config_save_path = os.path.join(self.exp_dir, "args.json")
         
     | 
| 184 | 
         
            +
             
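The resume branch above is driven by three runner arguments: ``args.resume``, ``args.resume_from_ckpt_path``, and ``args.resume_type``. A hypothetical argparse wiring that matches how they are consumed (not part of this file):

```python
# Hypothetical argparse wiring matching how the resume arguments are read above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--exp_name", type=str, default="exp")
parser.add_argument("--resume", action="store_true")
parser.add_argument("--resume_from_ckpt_path", type=str, default="")  # "" = use latest
parser.add_argument("--resume_type", type=str, default="resume")  # "resume" or "finetune"
args = parser.parse_args()
```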
    def _accelerator_prepare(self):
        (
            self.train_dataloader,
            self.valid_dataloader,
            self.model,
            self.optimizer,
            self.scheduler,
        ) = self.accelerator.prepare(
            self.train_dataloader,
            self.valid_dataloader,
            self.model,
            self.optimizer,
            self.scheduler,
        )

    ### Following are abstract methods that should be implemented in child classes ###
    @abstractmethod
    def _build_dataset(self):
        r"""Build dataset for model training/validating/evaluating."""
        pass

    @staticmethod
    @abstractmethod
    def _build_criterion():
        r"""Build criterion function for model loss calculation."""
        pass

    @abstractmethod
    def _build_model(self):
        r"""Build model for training/validating/evaluating."""
        pass

    @abstractmethod
    def _forward_step(self, batch):
        r"""One forward step of the neural network. This abstract method aims to
        unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
        However, for special cases that require different forward-step patterns for
        training and validating, you can override this method with ``pass`` and
        implement ``_train_step`` and ``_valid_step`` separately.
        """
        pass

    @abstractmethod
    def _save_auxiliary_states(self):
        r"""Save some auxiliary states when saving the model's checkpoint."""
        pass

    ### Abstract methods end ###

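To make the contract concrete, here is a minimal child-class sketch; ``MyDataset``, ``MyCollator``, and ``MyModel`` are hypothetical placeholders, and the L1 loss is an arbitrary choice, not something this file prescribes:

```python
# Illustrative sketch only; MyDataset, MyCollator, and MyModel are hypothetical.
import torch.nn.functional as F


class MyTrainer(BaseTrainer):
    def _build_dataset(self):
        # Return the Dataset/Collator classes; _build_dataloader instantiates them.
        return MyDataset, MyCollator

    @staticmethod
    def _build_criterion():
        return F.l1_loss

    def _build_model(self):
        return MyModel(self.cfg.model)

    def _forward_step(self, batch):
        # Must return a scalar loss tensor; train_loop backpropagates through it.
        y_pred = self.model(batch["x"])
        return self.criterion(y_pred, batch["y"])

    def _save_auxiliary_states(self):
        # Nothing extra to persist in this sketch.
        pass
```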
    ### THIS IS MAIN ENTRY ###
    def train_loop(self):
        r"""Training loop. The public entry of the training process."""
        # Wait for everyone to prepare before we move on
        self.accelerator.wait_for_everyone()
        # dump config file
        if self.accelerator.is_main_process:
            self.__dump_cfg(self.config_save_path)
        self.model.train()
        self.optimizer.zero_grad()
        # Wait to ensure good to go
        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
            ### It's inconvenient for the model with multiple losses
            # Do training & validating epoch
            train_loss = self._train_epoch()
            self.logger.info("  |- Train/Loss: {:.6f}".format(train_loss))
            valid_loss = self._valid_epoch()
            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_loss))
            self.accelerator.log(
                {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()
            # TODO: what is scheduler?
            self.scheduler.step(valid_loss)  # FIXME: use epoch track correct?

            # Check if hit save_checkpoint_stride and run_eval
            run_eval = False
            if self.accelerator.is_main_process:
                save_checkpoint = False
                hit_idx = []
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        hit_idx.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_loss
                    ),
                )
                self.tmp_checkpoint_save_path = path
                self.accelerator.save_state(path)
                print(f"save checkpoint in {path}")
                json.dump(
                    self.checkpoints_path,
                    open(os.path.join(path, "ckpts.json"), "w"),
                    ensure_ascii=False,
                    indent=4,
                )
                self._save_auxiliary_states()

                # Remove old checkpoints
                to_remove = []
                for idx in hit_idx:
                    self.checkpoints_path[idx].append(path)
                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))

                # Search conflicts
                total = set()
                for i in self.checkpoints_path:
                    total |= set(i)
                do_remove = set()
                for idx, path in to_remove[::-1]:
                    if path in total:
                        self.checkpoints_path[idx].insert(0, path)
                    else:
                        do_remove.add(path)

                # Remove old checkpoints
                for path in do_remove:
                    shutil.rmtree(path, ignore_errors=True)
                    self.logger.debug(f"Remove old checkpoint: {path}")

            self.accelerator.wait_for_everyone()
            if run_eval:
                # TODO: run evaluation
                pass

            # Update info for each epoch
            self.epoch += 1

        # Finish training and save final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self.accelerator.save_state(
                os.path.join(
                    self.checkpoint_dir,
                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, valid_loss
                    ),
                )
            )
            self._save_auxiliary_states()

        self.accelerator.end_training()

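The checkpoint-rotation bookkeeping in ``train_loop`` is easiest to see with toy values. This standalone sketch (hypothetical strides and retention counts) mirrors the ``keep_last`` logic, including the rule that a path is deleted only when no stride's history still references it:

```python
# Toy illustration of the checkpoint rotation above; all values are hypothetical.
save_checkpoint_stride = [1, 5]  # save every epoch, and additionally every 5th epoch
keep_last = [2, 10**9]           # keep 2 recent per-epoch ckpts, keep all 5-epoch ckpts
checkpoints_path = [[], []]      # one history list per stride

for epoch in range(1, 8):
    hit_idx = [i for i, n in enumerate(save_checkpoint_stride) if epoch % n == 0]
    path = f"epoch-{epoch:04d}"
    to_remove = []
    for idx in hit_idx:
        checkpoints_path[idx].append(path)
        while len(checkpoints_path[idx]) > keep_last[idx]:
            to_remove.append(checkpoints_path[idx].pop(0))
    # A checkpoint is deleted only if *no* stride's history still references it.
    survivors = {p for history in checkpoints_path for p in history}
    deleted = [p for p in to_remove if p not in survivors]
    print(epoch, checkpoints_path, "delete:", deleted)
```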
    ### Following are methods that can be used directly in child classes ###
    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        self.model.train()
        epoch_sum_loss: float = 0.0
        epoch_step: int = 0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                loss = self._train_step(batch)
                self.accelerator.backward(loss)
                self.optimizer.step()
                self.optimizer.zero_grad()
            self.batch_count += 1

            # Update info for each step
            # TODO: step means BP counts or batch counts?
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += loss
                self.accelerator.log(
                    {
                        "Step/Train Loss": loss,
                        "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
                    },
                    step=self.step,
                )
                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()
        return (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )

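The return expression compensates for gradient accumulation: ``epoch_sum_loss`` receives one summand per accumulation cycle, i.e. roughly ``len(dataloader) / gradient_accumulation_step`` of them, so dividing by the dataloader length alone would undercount. A quick check with toy numbers:

```python
# Why _train_epoch multiplies by gradient_accumulation_step (toy numbers).
num_batches = 100                              # len(self.train_dataloader)
accum = 4                                      # cfg.train.gradient_accumulation_step
epoch_sum_loss = (num_batches // accum) * 0.5  # loss logged 25 times, each 0.5
mean_loss = epoch_sum_loss / num_batches * accum
assert mean_loss == 0.5                        # recovers the per-logged-step average
```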
    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Validation epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        self.model.eval()
        epoch_sum_loss = 0.0
        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            batch_loss = self._valid_step(batch)
            epoch_sum_loss += batch_loss.item()

        self.accelerator.wait_for_everyone()
        return epoch_sum_loss / len(self.valid_dataloader)

    def _train_step(self, batch):
        r"""Training forward step. Should return average loss of a sample over
        one batch. Invoking ``_forward_step`` is recommended except for special cases.
        See ``_train_epoch`` for usage.
        """
        return self._forward_step(batch)

    @torch.inference_mode()
    def _valid_step(self, batch):
        r"""Validation forward step. Should return average loss of a sample over
        one batch. Invoking ``_forward_step`` is recommended except for special cases.
        See ``_valid_epoch`` for usage.
        """
        return self._forward_step(batch)

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        resume_type: str = "",
    ):
        r"""Load model from checkpoint. If ``checkpoint_path`` is None, load the
        latest checkpoint in ``checkpoint_dir``; otherwise load the checkpoint
        specified by ``checkpoint_path``. **Only use this method after**
        ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
            self.logger.info("Resume from {}...".format(checkpoint_path))

        if resume_type in ["resume", ""]:
            # Load everything: model weights, optimizer, scheduler, and random states.
            self.accelerator.load_state(input_dir=checkpoint_path)

            # set epoch and step
            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1

        elif resume_type == "finetune":
            # Load only the model weights
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            self.logger.info("Loaded model weights for finetuning...")

        else:
            raise ValueError("resume_type must be `resume` or `finetune`.")

        return checkpoint_path

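The epoch/step recovery above depends on the directory naming convention used when saving in ``train_loop``; a quick sanity check with a hypothetical name:

```python
# How epoch and step are parsed out of a checkpoint directory name (hypothetical).
name = "epoch-0003_step-0012000_loss-0.123456"
epoch = int(name.split("_")[-3].split("-")[-1])  # "epoch-0003" -> 3
step = int(name.split("_")[-2].split("-")[-1])   # "step-0012000" -> 12000
assert (epoch, step) == (3, 12000)
```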
    def _build_dataloader(self):
        Dataset, Collator = self._build_dataset()

        # Build a dataset instance for each dataset and combine them by ConcatDataset
        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=False)
            datasets_list.append(subdataset)
        train_dataset = ConcatDataset(datasets_list)
        train_collate = Collator(self.cfg)
        _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
        self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
        self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
        # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
        train_loader = DataLoader(
            train_dataset,
            # shuffle=True,
            collate_fn=train_collate,
            batch_sampler=batch_sampler,
            num_workers=self.cfg.train.dataloader.num_worker,
            pin_memory=self.cfg.train.dataloader.pin_memory,
        )

        # Build valid dataloader
        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=True)
            datasets_list.append(subdataset)
        valid_dataset = ConcatDataset(datasets_list)
        valid_collate = Collator(self.cfg)
        _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
        self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
        self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
        valid_loader = DataLoader(
            valid_dataset,
            collate_fn=valid_collate,
            batch_sampler=batch_sampler,
            num_workers=self.cfg.train.dataloader.num_worker,
            pin_memory=self.cfg.train.dataloader.pin_memory,
        )
        return train_loader, valid_loader

    @staticmethod
    def _set_random_seed(seed):
        r"""Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

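In recent PyTorch releases ``torch.random.manual_seed`` also seeds the CUDA generators, but it does not make cuDNN kernels deterministic. If stricter reproducibility is needed, a common (slower) addition, not part of this trainer, is:

```python
# Optional extra (not part of this trainer): trade speed for cuDNN determinism.
import torch

torch.backends.cudnn.deterministic = True  # force deterministic conv algorithms
torch.backends.cudnn.benchmark = False     # disable the nondeterministic auto-tuner
```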
    def _check_nan(self, loss, y_pred, y_gt):
        if torch.any(torch.isnan(loss)):
            self.logger.error("Fatal Error: Training stopped because the loss is NaN!")
            self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)

            ### y_pred ###
            if torch.any(torch.isnan(y_pred)):
                self.logger.error(
                    f"y_pred has NaN: {torch.any(torch.isnan(y_pred))}", in_order=True
                )
                self.logger.error(f"y_pred: {y_pred}", in_order=True)
            else:
                self.logger.debug(
                    f"y_pred has NaN: {torch.any(torch.isnan(y_pred))}", in_order=True
                )
                self.logger.debug(f"y_pred: {y_pred}", in_order=True)

            ### y_gt ###
            if torch.any(torch.isnan(y_gt)):
                self.logger.error(
                    f"y_gt has NaN: {torch.any(torch.isnan(y_gt))}", in_order=True
                )
                self.logger.error(f"y_gt: {y_gt}", in_order=True)
            else:
                self.logger.debug(
                    f"y_gt has NaN: {torch.any(torch.isnan(y_gt))}", in_order=True
                )
                self.logger.debug(f"y_gt: {y_gt}", in_order=True)

            self.accelerator.end_training()
            raise RuntimeError("Loss has NaN! See log for more info.")

    ### Protected methods end ###

    ## Following are private methods ##
    def _build_optimizer(self):
        r"""Build optimizer for the model."""
        # Make case-insensitive matching
        if self.cfg.train.optimizer.lower() == "adadelta":
            optimizer = torch.optim.Adadelta(
                self.model.parameters(), **self.cfg.train.adadelta
            )
            self.logger.info("Using Adadelta optimizer.")
        elif self.cfg.train.optimizer.lower() == "adagrad":
            optimizer = torch.optim.Adagrad(
                self.model.parameters(), **self.cfg.train.adagrad
            )
            self.logger.info("Using Adagrad optimizer.")
        elif self.cfg.train.optimizer.lower() == "adam":
            optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
            self.logger.info("Using Adam optimizer.")
        elif self.cfg.train.optimizer.lower() == "adamw":
            optimizer = torch.optim.AdamW(
                self.model.parameters(), **self.cfg.train.adamw
            )
        elif self.cfg.train.optimizer.lower() == "sparseadam":
            optimizer = torch.optim.SparseAdam(
                self.model.parameters(), **self.cfg.train.sparseadam
            )
        elif self.cfg.train.optimizer.lower() == "adamax":
            optimizer = torch.optim.Adamax(
                self.model.parameters(), **self.cfg.train.adamax
            )
        elif self.cfg.train.optimizer.lower() == "asgd":
            optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
        elif self.cfg.train.optimizer.lower() == "lbfgs":
            optimizer = torch.optim.LBFGS(
                self.model.parameters(), **self.cfg.train.lbfgs
            )
        elif self.cfg.train.optimizer.lower() == "nadam":
            optimizer = torch.optim.NAdam(
                self.model.parameters(), **self.cfg.train.nadam
            )
        elif self.cfg.train.optimizer.lower() == "radam":
            optimizer = torch.optim.RAdam(
                self.model.parameters(), **self.cfg.train.radam
            )
        elif self.cfg.train.optimizer.lower() == "rmsprop":
            optimizer = torch.optim.RMSprop(
                self.model.parameters(), **self.cfg.train.rmsprop
            )
        elif self.cfg.train.optimizer.lower() == "rprop":
            optimizer = torch.optim.Rprop(
                self.model.parameters(), **self.cfg.train.rprop
            )
        elif self.cfg.train.optimizer.lower() == "sgd":
            optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
        else:
            raise NotImplementedError(
                f"Optimizer {self.cfg.train.optimizer} not supported yet!"
            )
        return optimizer

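Each branch unpacks one ``cfg.train`` sub-section as keyword arguments, so the optimizer is configured entirely from the config file. A hypothetical fragment and its effect:

```python
# Hypothetical cfg.train fragment and how _build_optimizer consumes it.
cfg_train = {
    "optimizer": "AdamW",  # matched case-insensitively
    "adamw": {"lr": 1e-4, "betas": (0.9, 0.999), "weight_decay": 1e-2},
}
# Effectively expands to:
#   torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-2)
```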
     | 
| 608 | 
         
            +
                def _build_scheduler(self):
         
     | 
| 609 | 
         
            +
                    r"""Build scheduler for optimizer."""
         
     | 
| 610 | 
         
            +
                    # Make case-insensitive matching
         
     | 
| 611 | 
         
            +
                    if self.cfg.train.scheduler.lower() == "lambdalr":
         
     | 
| 612 | 
         
            +
                        scheduler = torch.optim.lr_scheduler.LambdaLR(
         
     | 
| 613 | 
         
            +
                            self.optimizer, **self.cfg.train.lambdalr
         
     | 
| 614 | 
         
            +
                        )
         
     | 
| 615 | 
         
            +
                    elif self.cfg.train.scheduler.lower() == "multiplicativelr":
         
     | 
| 616 | 
         
            +
                        scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
         
     | 
| 617 | 
         
            +
                            self.optimizer, **self.cfg.train.multiplicativelr
         
     | 
| 618 | 
         
            +
                        )
         
     | 
| 619 | 
         
            +
                    elif self.cfg.train.scheduler.lower() == "steplr":
         
     | 
| 620 | 
         
            +
                        scheduler = torch.optim.lr_scheduler.StepLR(
         
     | 
| 621 | 
         
            +
                            self.optimizer, **self.cfg.train.steplr
         
     | 
| 622 | 
         
            +
                        )
         
     | 
| 623 | 
         
            +
                    elif self.cfg.train.scheduler.lower() == "multisteplr":
         
     | 
| 624 | 
         
            +
                        scheduler = torch.optim.lr_scheduler.MultiStepLR(
         
     | 
| 625 | 
         
            +
                            self.optimizer, **self.cfg.train.multisteplr
         
     | 
| 626 | 
         
            +
                        )
         
     | 
| 627 | 
         
            +
                    elif self.cfg.train.scheduler.lower() == "constantlr":
         
     | 
| 628 | 
         
            +
                        scheduler = torch.optim.lr_scheduler.ConstantLR(
         
     | 
| 629 | 
         
            +
                            self.optimizer, **self.cfg.train.constantlr
         
     | 
| 630 | 
         
            +
                        )
         
     | 
| 631 | 
         
            +
                    elif self.cfg.train.scheduler.lower() == "linearlr":
         
     | 
| 632 | 
         
            +
                        scheduler = torch.optim.lr_scheduler.LinearLR(
         
     | 
| 633 | 
         
            +
                            self.optimizer, **self.cfg.train.linearlr
         
     | 
| 634 | 
         
            +
                        )
         
     | 
| 635 | 
         
            +
                    elif self.cfg.train.scheduler.lower() == "exponentiallr":
         
     | 
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                self.optimizer, **self.cfg.train.exponentiallr
            )
        elif self.cfg.train.scheduler.lower() == "polynomiallr":
            scheduler = torch.optim.lr_scheduler.PolynomialLR(
                self.optimizer, **self.cfg.train.polynomiallr
            )
        elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, **self.cfg.train.cosineannealinglr
            )
        elif self.cfg.train.scheduler.lower() == "sequentiallr":
            scheduler = torch.optim.lr_scheduler.SequentialLR(
                self.optimizer, **self.cfg.train.sequentiallr
            )
        elif self.cfg.train.scheduler.lower() == "reducelronplateau":
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, **self.cfg.train.reducelronplateau
            )
        elif self.cfg.train.scheduler.lower() == "cycliclr":
            scheduler = torch.optim.lr_scheduler.CyclicLR(
                self.optimizer, **self.cfg.train.cycliclr
            )
        elif self.cfg.train.scheduler.lower() == "onecyclelr":
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                self.optimizer, **self.cfg.train.onecyclelr
            )
        elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
            # NOTE: "annearing" is misspelled, but the string must match the
            # config key of the same name, so it is kept as-is.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
            )
        elif self.cfg.train.scheduler.lower() == "noamlr":
            # NOTE: NoamLR reads cfg.train.lr_scheduler rather than a key named "noamlr".
            scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
        else:
            raise NotImplementedError(
                f"Scheduler {self.cfg.train.scheduler} not supported yet!"
            )
        return scheduler
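    # A sketch of the cfg.train fragment this dispatch expects (illustrative
    # values, not the repo's defaults): the sub-dict named after the scheduler
    # is splatted straight into the matching torch constructor, so its keys
    # must be valid constructor arguments, e.g.
    #
    #   "train": {
    #       "scheduler": "cosineannealinglr",
    #       "cosineannealinglr": {"T_max": 100000, "eta_min": 1e-6},
    #   }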

    def _init_accelerator(self):
        self.exp_dir = os.path.join(
            os.path.abspath(self.cfg.log_dir), self.args.exp_name
        )
        project_config = ProjectConfiguration(
            project_dir=self.exp_dir,
            logging_dir=os.path.join(self.exp_dir, "log"),
        )
        self.accelerator = accelerate.Accelerator(
            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
            log_with=self.cfg.train.tracker,
            project_config=project_config,
        )
        if self.accelerator.is_main_process:
            os.makedirs(project_config.project_dir, exist_ok=True)
            os.makedirs(project_config.logging_dir, exist_ok=True)
        with self.accelerator.main_process_first():
            self.accelerator.init_trackers(self.args.exp_name)
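    # Usage note (sketch): with gradient_accumulation_steps set on the
    # Accelerator above, the training loop typically wraps each step in
    #
    #   with self.accelerator.accumulate(self.model):
    #       loss = self.training_step(batch)   # hypothetical method name
    #       self.accelerator.backward(loss)
    #
    # so optimizer updates only fire once every N micro-batches.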

    def __check_basic_configs(self):
        if self.cfg.train.gradient_accumulation_step <= 0:
            self.logger.fatal("Invalid gradient_accumulation_step value!")
            self.logger.error(
                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
            )
            self.accelerator.end_training()
            raise ValueError(
                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
            )
        # TODO: check other values
    @staticmethod
    def __count_parameters(model):
        model_param = 0.0
        if isinstance(model, dict):
            for key, value in model.items():
                model_param += sum(p.numel() for p in value.parameters())
        else:
            model_param = sum(p.numel() for p in model.parameters())
        return model_param
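    # Sketch: the raw element count divides down to the familiar "M" figure,
    # e.g. self.logger.info(
    #     f"Model params: {self.__count_parameters(self.model) / 1e6:.2f}M"
    # )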
    def __dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Use a context manager so the file handle is closed promptly.
        with open(path, "w") as f:
            json5.dump(
                self.cfg,
                f,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
                quote_keys=True,
            )

    ### Private methods end ###
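As a round-trip sketch for the config dump above (the path here is hypothetical), the json5 package can read its own quoted-key output back:

    import json5

    with open("exp/my_run/args.json5") as f:  # hypothetical dump path
        cfg = json5.load(f)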
    	
models/codec/__init__.py  ADDED
    (empty file)
    	
models/codec/amphion_codec/codec.py  ADDED
    @@ -0,0 +1,427 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from models.codec.amphion_codec.quantize import (
    ResidualVQ,
    VectorQuantize,
    FactorizedVectorQuantize,
    LookupFreeQuantize,
)

from models.codec.amphion_codec.vocos import Vocos


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

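# Quick numeric check of snake(x) = x + (1/α)·sin²(α·x) (a sketch; the check
# just reflects the α = 1 initialization of Snake1d above):
#
#   x = torch.linspace(-3, 3, 7).reshape(1, 1, -1)          # (B, C, T) layout
#   y = Snake1d(channels=1)(x)
#   torch.allclose(y, x + torch.sin(x).pow(2), atol=1e-6)   # -> True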
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        # With the "same"-style padding above, y matches x in length and the
        # trim below is a no-op; it only fires if a variant shrinks the sequence.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)

class CodecEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        up_ratios: list = [4, 5, 5, 6],
        out_channels: int = 256,
        use_tanh: bool = False,
        cfg=None,
    ):
        super().__init__()

        d_model = cfg.d_model if cfg is not None else d_model
        up_ratios = cfg.up_ratios if cfg is not None else up_ratios
        out_channels = cfg.out_channels if cfg is not None else out_channels
        use_tanh = cfg.use_tanh if cfg is not None else use_tanh

        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in up_ratios:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        if use_tanh:
            self.block += [nn.Tanh()]

        # Wrap the block list into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        return self.block(x)

    def reset_parameters(self):
        self.apply(init_weights)

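# Shape sketch for CodecEncoder (defaults d_model=64, up_ratios=[4, 5, 5, 6]):
# every EncoderBlock doubles channels and downsamples by its stride, so one
# latent frame covers 4*5*5*6 = 600 samples. Illustrative check:
#
#   enc = CodecEncoder()              # out_channels defaults to 256
#   wav = torch.randn(1, 1, 24000)    # 1 s of audio at a hypothetical 24 kHz
#   enc(wav).shape                    # -> torch.Size([1, 256, 40]); 24000/600 = 40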
class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                # padding/output_padding are chosen so the transposed conv
                # upsamples by exactly `stride` for both even and odd strides.
                padding=stride // 2 + stride % 2,
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)

class CodecDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int = 256,
        upsample_initial_channel: int = 1536,
        up_ratios: list = [5, 5, 4, 2],
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",
        quantizer_dropout: float = 0.5,
        commitment: float = 0.25,
        codebook_loss_weight: float = 1.0,
        use_l2_normlize: bool = False,  # (sic: "normlize" matches usage across the codec modules)
        codebook_type: str = "euclidean",
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        eps: float = 1e-5,
        threshold_ema_dead_code: int = 2,
        weight_init: bool = False,
        use_vocos: bool = False,
        vocos_dim: int = 384,
        vocos_intermediate_dim: int = 1152,
        vocos_num_layers: int = 8,
        n_fft: int = 800,
        hop_size: int = 200,
        padding: str = "same",
        cfg=None,
    ):
        super().__init__()

        in_channels = (
            cfg.in_channels
            if cfg is not None and hasattr(cfg, "in_channels")
            else in_channels
        )
        upsample_initial_channel = (
            cfg.upsample_initial_channel
            if cfg is not None and hasattr(cfg, "upsample_initial_channel")
            else upsample_initial_channel
        )
        up_ratios = (
            cfg.up_ratios
            if cfg is not None and hasattr(cfg, "up_ratios")
            else up_ratios
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        quantizer_type = (
            cfg.quantizer_type
            if cfg is not None and hasattr(cfg, "quantizer_type")
            else quantizer_type
        )
        quantizer_dropout = (
            cfg.quantizer_dropout
            if cfg is not None and hasattr(cfg, "quantizer_dropout")
            else quantizer_dropout
        )
        commitment = (
            cfg.commitment
            if cfg is not None and hasattr(cfg, "commitment")
            else commitment
        )
        codebook_loss_weight = (
            cfg.codebook_loss_weight
            if cfg is not None and hasattr(cfg, "codebook_loss_weight")
            else codebook_loss_weight
        )
        use_l2_normlize = (
            cfg.use_l2_normlize
            if cfg is not None and hasattr(cfg, "use_l2_normlize")
            else use_l2_normlize
        )
        codebook_type = (
            cfg.codebook_type
            if cfg is not None and hasattr(cfg, "codebook_type")
            else codebook_type
        )
        kmeans_init = (
            cfg.kmeans_init
            if cfg is not None and hasattr(cfg, "kmeans_init")
            else kmeans_init
        )
        kmeans_iters = (
            cfg.kmeans_iters
            if cfg is not None and hasattr(cfg, "kmeans_iters")
            else kmeans_iters
        )
        decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay
        eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps
        threshold_ema_dead_code = (
            cfg.threshold_ema_dead_code
            if cfg is not None and hasattr(cfg, "threshold_ema_dead_code")
            else threshold_ema_dead_code
        )
        weight_init = (
            cfg.weight_init
            if cfg is not None and hasattr(cfg, "weight_init")
            else weight_init
        )
        use_vocos = (
            cfg.use_vocos
            if cfg is not None and hasattr(cfg, "use_vocos")
            else use_vocos
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
        n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
        hop_size = (
            cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
        )
        padding = (
            cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
        )

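        # The ladder above could be written once with a tiny helper; a sketch
        # (hypothetical name, not part of this file):
        #
        #   def cfg_or(name, default):
        #       return getattr(cfg, name, default) if cfg is not None else default
        #
        #   in_channels = cfg_or("in_channels", in_channels)
        #   ...and so on for the remaining keyword arguments.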
        if quantizer_type == "vq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
                codebook_type=codebook_type,
                kmeans_init=kmeans_init,
                kmeans_iters=kmeans_iters,
                decay=decay,
                eps=eps,
                threshold_ema_dead_code=threshold_ema_dead_code,
                weight_init=weight_init,
            )
        elif quantizer_type == "fvq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
                quantizer_dropout=quantizer_dropout,
                commitment=commitment,
                codebook_loss_weight=codebook_loss_weight,
                use_l2_normlize=use_l2_normlize,
            )
        elif quantizer_type == "lfq":
            self.quantizer = ResidualVQ(
                input_dim=in_channels,
                num_quantizers=num_quantizers,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_type=quantizer_type,
            )
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        if not use_vocos:
            # Add first conv layer
            channels = upsample_initial_channel
            layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

            # Add upsampling + MRF blocks
            for i, stride in enumerate(up_ratios):
                input_dim = channels // 2**i
                output_dim = channels // 2 ** (i + 1)
                layers += [DecoderBlock(input_dim, output_dim, stride)]

            # Add final conv layer
            layers += [
                Snake1d(output_dim),
                WNConv1d(output_dim, 1, kernel_size=7, padding=3),
                nn.Tanh(),
            ]

            self.model = nn.Sequential(*layers)

        if use_vocos:
            self.model = Vocos(
                input_channels=in_channels,
                dim=vocos_dim,
                intermediate_dim=vocos_intermediate_dim,
                num_layers=vocos_num_layers,
                adanorm_num_embeddings=None,
                n_fft=n_fft,
                hop_size=hop_size,
                padding=padding,
            )

        self.reset_parameters()

    def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None):
        """
        If vq is True, x is the encoder output and the quantized output is
        returned; otherwise, x is the quantized output and the decoder
        output is returned.
        """
        if vq is True:
            if eval_vq:
                self.quantizer.eval()
            (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            ) = self.quantizer(x, n_quantizers=n_quantizers)
            return (
                quantized_out,
                all_indices,
                all_commit_losses,
                all_codebook_losses,
                all_quantized,
            )

        return self.model(x)

    def quantize(self, x, n_quantizers=None):
        self.quantizer.eval()
        quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers)
        return quantized_out, vq

    # TODO: check consistency of vq2emb and quantize
    def vq2emb(self, vq, n_quantizers=None):
        return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)

    def decode(self, x):
        return self.model(x)

    def latent2dist(self, x, n_quantizers=None):
        return self.quantizer.latent2dist(x, n_quantizers=n_quantizers)

    def reset_parameters(self):
        self.apply(init_weights)
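A minimal end-to-end sketch of the two-stage CodecDecoder protocol above (shapes follow the defaults; assumes the default ResidualVQ config initializes cleanly):

    import torch
    from models.codec.amphion_codec.codec import CodecDecoder

    dec = CodecDecoder()                     # quantizer_type="vq", up_ratios=[5, 5, 4, 2]
    z = torch.randn(1, 256, 40)              # encoder latents (B, C, frames)

    # Stage 1: vq=True routes through the residual VQ and returns stats
    quantized, indices, commit_losses, codebook_losses, quantized_per_layer = dec(
        z, vq=True, eval_vq=False
    )

    # Stage 2: decode quantized latents to a waveform; total upsampling is
    # 5*5*4*2 = 200, so 40 frames -> (1, 1, 8000) samples
    wav_hat = dec(quantized)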
    	
models/codec/amphion_codec/quantize/__init__.py  ADDED
    @@ -0,0 +1,11 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
    FactorizedVectorQuantize,
)
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize
from models.codec.amphion_codec.quantize.residual_vq import ResidualVQ
    	
models/codec/amphion_codec/quantize/factorized_vector_quantize.py  ADDED
    @@ -0,0 +1,150 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class FactorizedVectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Straight-through estimator: forward uses z_q, gradients flow to z_e
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute euclidean distance between encodings and codebook;
        # if use_l2_normlize is True, the distance is equivalent to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        emb = self.decode_code(vq)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2 normalize encodings and codebook
        if self.use_l2_normlize:
            encodings = F.normalize(encodings)
            codebook = F.normalize(codebook)

        # Compute euclidean distance between encodings and codebook;
        # if use_l2_normlize is True, the distance is equivalent to cosine distance
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )  # (b*t, k)

        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0))
        z_q = self.decode_code(indices)

        return -dist, indices, z_q
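As a quick shape check for `FactorizedVectorQuantize`: the input is projected from `input_dim` down to `codebook_dim` for the nearest-neighbour search, quantized with a straight-through estimator, then projected back up. A minimal sketch (dimensions are arbitrary, not Amphion defaults):

    import torch
    from models.codec.amphion_codec.quantize import FactorizedVectorQuantize

    fvq = FactorizedVectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=8)
    z = torch.randn(2, 256, 50)  # (B, D, T) encoder latents
    z_q, commit_loss, codebook_loss, indices, z_e = fvq(z)
    assert z_q.shape == z.shape      # quantized output, back in input_dim
    assert z_e.shape == (2, 8, 50)   # low-dimensional projected latents
    assert indices.shape == (2, 50)  # one codebook index per frame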
    	
models/codec/amphion_codec/quantize/lookup_free_quantize.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class LookupFreeQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Each latent channel carries one bit, so the implicit codebook
        # must have exactly 2**codebook_dim entries.
        assert 2**codebook_dim == codebook_size

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

    def forward(self, z):
        z_e = self.in_project(z)
        z_e = F.sigmoid(z_e)

        # Straight-through rounding of each channel to {0, 1}
        z_q = z_e + (torch.round(z_e) - z_e).detach()

        z_q = self.out_project(z_q)

        commit_loss = torch.zeros(z.shape[0], device=z.device)
        codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Pack the binary code into an integer index: bit i weighs 2**i
        bits = (
            2
            ** torch.arange(self.codebook_dim, device=z.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .long()
        )  # (1, d, 1)
        indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()

        return z_q, commit_loss, codebook_loss, indices, z_e

    def vq2emb(self, vq, out_proj=True):
        emb = torch.zeros(
            vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
        )  # (B, d, T)
        for i in range(self.codebook_dim):
            emb[:, i, :] = (vq % 2).float()
            vq = vq // 2
        if out_proj:
            emb = self.out_project(emb)
        return emb
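The index computation in `forward` is plain binary packing: after the sigmoid, each of the `codebook_dim` channels is rounded to a bit, and bit `i` contributes `2**i` to the code index; `vq2emb` inverts this. A round-trip sketch (sizes are arbitrary):

    import torch
    from models.codec.amphion_codec.quantize import LookupFreeQuantize

    lfq = LookupFreeQuantize(input_dim=64, codebook_size=256, codebook_dim=8)
    z = torch.randn(2, 64, 10)
    z_q, _, _, indices, z_e = lfq(z)
    assert indices.shape == (2, 10) and int(indices.max()) < 256

    # Unpacking the indices recovers the rounded binary code exactly
    code = lfq.vq2emb(indices, out_proj=False)
    assert torch.equal(code, torch.round(z_e))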
    	
models/codec/amphion_codec/quantize/residual_vq.py
ADDED
@@ -0,0 +1,177 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from models.codec.amphion_codec.quantize.factorized_vector_quantize import (
    FactorizedVectorQuantize,
)
from models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize
from models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize


class ResidualVQ(nn.Module):
    """
    Introduced in SoundStream: An End-to-End Neural Audio Codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 256,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
        codebook_dim: int = 256,
        quantizer_type: str = "vq",  # "vq", "fvq", or "lfq"
        quantizer_dropout: float = 0.5,
        **kwargs,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_type = quantizer_type
        self.quantizer_dropout = quantizer_dropout

        if quantizer_type == "vq":
            VQ = VectorQuantize
        elif quantizer_type == "fvq":
            VQ = FactorizedVectorQuantize
        elif quantizer_type == "lfq":
            VQ = LookupFreeQuantize
        else:
            raise ValueError(f"Unknown quantizer type {quantizer_type}")

        self.quantizers = nn.ModuleList(
            [
                VQ(
                    input_dim=input_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(self, z, n_quantizers: int = None):
        """
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            Number of quantizers to use
            (n_quantizers < self.num_quantizers, e.g. for quantizer dropout)
            Note: in training mode this argument is ignored; a random number of
                quantizers is used for a `quantizer_dropout` fraction of the batch.
        Returns
        -------
        "quantized_out" : Tensor[B x D x T]
            Quantized continuous representation of input
        "all_indices" : Tensor[N x B x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        "all_commit_losses" : Tensor[N]
        "all_codebook_losses" : Tensor[N]
        "all_quantized" : Tensor[N x B x D x T]
        """

        quantized_out = 0.0
        residual = z

        all_commit_losses = []
        all_codebook_losses = []
        all_indices = []
        all_quantized = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1
            dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            quantized_out = quantized_out + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            commit_loss_i = (commit_loss_i * mask).mean()
            codebook_loss_i = (codebook_loss_i * mask).mean()

            all_commit_losses.append(commit_loss_i)
            all_codebook_losses.append(codebook_loss_i)
            all_indices.append(indices_i)
            all_quantized.append(z_q_i)

        all_commit_losses, all_codebook_losses, all_indices, all_quantized = map(
            torch.stack,
            (all_commit_losses, all_codebook_losses, all_indices, all_quantized),
        )

        return (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            all_quantized,
        )

    def vq2emb(self, vq, n_quantizers=None):
        quantized_out = 0.0
        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        for idx, quantizer in enumerate(self.quantizers):
            if idx >= n_quantizers:
                break
            quantized_out += quantizer.vq2emb(vq[idx])
        return quantized_out

    def latent2dist(self, z, n_quantizers=None):
        quantized_out = 0.0
        residual = z

        all_dists = []
        all_indices = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break
            dist_i, indices_i, z_q_i = quantizer.latent2dist(residual)
            all_dists.append(dist_i)
            all_indices.append(indices_i)

            quantized_out = quantized_out + z_q_i
            residual = residual - z_q_i

        all_dists = torch.stack(all_dists)
        all_indices = torch.stack(all_indices)

        return all_dists, all_indices
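`ResidualVQ` chains the single quantizers coarse-to-fine: stage `i` quantizes the residual left by stages `0..i-1`, so in eval mode (quantizer dropout off) the per-stage outputs sum to the final quantized latent. A sketch with `quantizer_type="fvq"`, where `codebook_dim` is passed through to each `FactorizedVectorQuantize` (all sizes arbitrary):

    import torch
    from models.codec.amphion_codec.quantize import ResidualVQ

    rvq = ResidualVQ(
        input_dim=256, num_quantizers=4, codebook_size=1024,
        codebook_dim=8, quantizer_type="fvq",
    )
    rvq.eval()  # use all quantizers deterministically

    z = torch.randn(2, 256, 50)
    z_q, all_indices, commit_losses, codebook_losses, all_quantized = rvq(z)
    assert all_indices.shape == (4, 2, 50)                # one id per stage/frame
    assert torch.allclose(z_q, all_quantized.sum(dim=0))  # stages sum to output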
    	
models/codec/amphion_codec/quantize/vector_quantize.py
ADDED
@@ -0,0 +1,401 @@
| 1 | 
         
            +
            # Copyright (c) 2024 Amphion.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This source code is licensed under the MIT license found in the
         
     | 
| 4 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import numpy as np
         
     | 
| 7 | 
         
            +
            import torch
         
     | 
| 8 | 
         
            +
            import torch.nn as nn
         
     | 
| 9 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 10 | 
         
            +
            from einops import rearrange, repeat
         
     | 
| 11 | 
         
            +
            from torch.nn.utils import weight_norm
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            def WNConv1d(*args, **kwargs):
         
     | 
| 15 | 
         
            +
                return weight_norm(nn.Conv1d(*args, **kwargs))
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            def WNConvTranspose1d(*args, **kwargs):
         
     | 
| 19 | 
         
            +
                return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
             
     | 
| 22 | 
         
            +
            def l2norm(t):
         
     | 
| 23 | 
         
            +
                return F.normalize(t, p=2, dim=-1)
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            def ema_inplace(moving_avg, new, decay):
         
     | 
| 27 | 
         
            +
                moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            def laplace_smoothing(x, n_categories, eps=1e-5):
         
     | 
| 31 | 
         
            +
                return (x + eps) / (x.sum() + n_categories * eps)
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
            def sample_vectors(samples, num):
         
     | 
| 35 | 
         
            +
                num_samples, device = samples.shape[0], samples.device
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                if num_samples >= num:
         
     | 
| 38 | 
         
            +
                    indices = torch.randperm(num_samples, device=device)[:num]
         
     | 
| 39 | 
         
            +
                else:
         
     | 
| 40 | 
         
            +
                    indices = torch.randint(0, num_samples, (num,), device=device)
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                return samples[indices]
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
            def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
         
     | 
| 46 | 
         
            +
                dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                means = sample_vectors(samples, num_clusters)
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                for _ in range(num_iters):
         
     | 
| 51 | 
         
            +
                    if use_cosine_sim:
         
     | 
| 52 | 
         
            +
                        dists = samples @ means.t()
         
     | 
| 53 | 
         
            +
                    else:
         
     | 
| 54 | 
         
            +
                        diffs = rearrange(samples, "n d -> n () d") - rearrange(
         
     | 
| 55 | 
         
            +
                            means, "c d -> () c d"
         
     | 
| 56 | 
         
            +
                        )
         
     | 
| 57 | 
         
            +
                        dists = -(diffs**2).sum(dim=-1)
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                    buckets = dists.max(dim=-1).indices
         
     | 
| 60 | 
         
            +
                    bins = torch.bincount(buckets, minlength=num_clusters)
         
     | 
| 61 | 
         
            +
                    zero_mask = bins == 0
         
     | 
| 62 | 
         
            +
                    bins_min_clamped = bins.masked_fill(zero_mask, 1)
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                    new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
         
     | 
| 65 | 
         
            +
                    new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
         
     | 
| 66 | 
         
            +
                    new_means = new_means / bins_min_clamped[..., None]
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                    if use_cosine_sim:
         
     | 
| 69 | 
         
            +
                        new_means = l2norm(new_means)
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
                    means = torch.where(zero_mask[..., None], means, new_means)
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                return means, bins
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
             
     | 
class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()

        self.decay = decay
        init_fn = torch.randn if not weight_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        if weight_init:
            nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size)

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer(
            "initted", torch.Tensor([not kmeans_init])
        )  # if kmeans_init is True, then initted is False; otherwise, initted is True
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    def init_embed_(self, data):
        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    def replace(self, samples, mask):
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace(batch_samples, mask=expired_codes)

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if not self.initted:
            self.init_embed_(flatten)

        # Negative squared L2 distance, expanded as ||x||^2 - 2 x.e + ||e||^2,
        # so the max over codes selects the nearest codebook entry.
        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        if self.training:
            # Exponential-moving-average codebook update (VQ-VAE style)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = (
                flatten.t() @ embed_onehot
            )  # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size)
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.eps)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        return quantize, embed_ind

    def vq2emb(self, vq):
        quantize = F.embedding(vq, self.embed)
        return quantize

    def latent2dist(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if not self.initted:
            self.init_embed_(flatten)

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize


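def _euclidean_codebook_sketch():
    # Hedged usage sketch added for illustration; not part of the original
    # diff. Shapes are assumptions: (batch, time, dim) latents, 256 codes.
    codebook = EuclideanCodebook(dim=8, codebook_size=256)
    x = torch.randn(4, 50, 8)
    quantize, embed_ind = codebook(x)
    assert quantize.shape == (4, 50, 8)  # nearest codebook vectors
    assert embed_ind.shape == (4, 50)    # per-position code indices

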
class SimpleCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        use_l2_normlize=False,
    ):
        super().__init__()

        self.dim = dim
        self.codebook_size = codebook_size
        self.use_l2_normlize = use_l2_normlize

        self.embed = nn.Embedding(self.codebook_size, self.dim)

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed, dim=0)  # unit-normalize each code vector

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.weight)

        return quantize, embed_ind

    def vq2emb(self, vq):
        quantize = F.embedding(vq, self.embed.weight)
        return quantize

    def latent2dist(self, x):
        shape, dtype = x.shape, x.dtype
        flatten = rearrange(x, "... d -> (...) d")
        embed = self.embed.weight.t()  # (codebook_size, dim) -> (dim, codebook_size)

        if self.use_l2_normlize:
            flatten = F.normalize(flatten)
            embed = F.normalize(embed, dim=0)  # unit-normalize each code vector

        dist = -(
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )

        embed_ind = dist.max(dim=-1).indices
        embed_ind = embed_ind.view(*shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.weight)

        dist = dist.view(*shape[:-1], -1)

        return dist, embed_ind, quantize


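def _simple_codebook_sketch():
    # Hedged usage sketch added for illustration; not part of the original
    # diff. With use_l2_normlize=True the nearest-code search reduces to
    # cosine similarity, since ||x - e||^2 = 2 - 2 x.e for unit-norm x and e.
    codebook = SimpleCodebook(dim=8, codebook_size=256, use_l2_normlize=True)
    quantize, embed_ind = codebook(torch.randn(4, 50, 8))
    assert quantize.shape == (4, 50, 8)
    assert embed_ind.shape == (4, 50)

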
class VectorQuantize(nn.Module):
    """Vector quantization and factorized vector quantization implementation.

    Args:
        input_dim (int): Dimension of input.
        codebook_size (int): Codebook size.
        codebook_dim (int): Codebook dimension. We suggest codebook_dim = input_dim
            if codebook_type == "euclidean"; otherwise, if you want to use
            factorized vector quantization, use a small codebook_dim (e.g., 8 or 32).
        commitment (float): Weight for the commitment loss.
        codebook_loss_weight (float): Weight for the codebook loss.
        use_l2_normlize (bool): Whether to use l2-normalized codes for factorized
            vector quantization; we suggest setting it to True when using
            factorized vector quantization.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        decay (float): Decay for the exponential moving average over the codebooks.
        eps (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any
            codes that have an exponential moving average cluster size less than the
            specified threshold with randomly selected vectors from the current batch.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=0.005,
        codebook_loss_weight=1.0,
        use_l2_normlize=False,
        codebook_type="euclidean",  # "euclidean" or "simple"
        kmeans_init=False,
        kmeans_iters=10,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        weight_init=False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.use_l2_normlize = use_l2_normlize
        self.codebook_type = codebook_type
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.decay = decay
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.weight_init = weight_init

        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        if self.codebook_type == "euclidean":
            self.codebook = EuclideanCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                kmeans_init=self.kmeans_init,
                kmeans_iters=self.kmeans_iters,
                decay=self.decay,
                eps=self.eps,
                threshold_ema_dead_code=self.threshold_ema_dead_code,
                weight_init=self.weight_init,
            )
        elif self.codebook_type == "simple":
            self.codebook = SimpleCodebook(
                self.codebook_dim,
                codebook_size=self.codebook_size,
                use_l2_normlize=self.use_l2_normlize,
            )
        else:
            raise NotImplementedError(
                f"codebook_type {self.codebook_type} is not implemented!"
            )

    def forward(self, z):
        """
        Parameters
        ----------
        z: torch.Tensor[B x D x T]

        Returns
        -------
        z_q: torch.Tensor[B x D x T]
            Quantized continuous representation of input
        commit_loss: Tensor[B]
            Commitment loss to train encoder to predict vectors closer to codebook entries
        codebook_loss: Tensor[B]
            Codebook loss to update the codebook
        indices: torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        z_e: torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim
        z_e = self.in_project(z)
        z_q, indices = self.decode_latents(z_e)

        # Compute commitment loss and codebook loss
        if self.training:
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)
            codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Straight-through estimator: gradients flow from z_q back to z_e
        z_q = z_e + (z_q - z_e).detach()

        z_q = self.out_project(z_q)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> b t d")
        z_q, indices = self.codebook(encodings)
        z_q = z_q.transpose(1, 2)
        return z_q, indices

    def vq2emb(self, vq, out_proj=True):
        emb = self.codebook.vq2emb(vq)
        emb = emb.transpose(1, 2)
        if out_proj:
            emb = self.out_project(emb)
        return emb

    def latent2dist(self, latents):
        latents = rearrange(latents, "b d t -> b t d")
        dist, embed_ind, quantize = self.codebook.latent2dist(latents)
        return dist, embed_ind, quantize.transpose(1, 2)
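
def _vector_quantize_sketch():
    # Hedged usage sketch added for illustration; not part of the original
    # diff. Factorized VQ: 256-dim latents projected into an 8-dim codebook
    # space; all sizes here are assumptions.
    vq = VectorQuantize(input_dim=256, codebook_size=1024, codebook_dim=8)
    z = torch.randn(2, 256, 100)  # (B, D, T)
    z_q, commit_loss, codebook_loss, indices, z_e = vq(z)
    assert z_q.shape == (2, 256, 100)  # straight-through quantized latents
    assert indices.shape == (2, 100)   # one codebook index per frame
    assert commit_loss.shape == (2,)   # per-item commitment loss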
         
models/codec/amphion_codec/vocos.py
ADDED

@@ -0,0 +1,881 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
from torch.nn.utils import weight_norm, remove_weight_norm
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
import librosa


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)


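def _symlog_roundtrip_sketch():
    # Hedged sketch added for illustration; not part of the original diff:
    # symexp is the inverse of symlog, and both are sign-preserving.
    x = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
    assert torch.allclose(symexp(symlog(x)), x, atol=1e-6)

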
class STFT(nn.Module):
    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center=True,
    ):
        super().__init__()
        self.center = center
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T * hop_length)

        if not self.center:
            pad = self.win_length - self.hop_length
            x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")

        stft_spec = torch.stft(
            x,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            return_complex=False,
        )  # (B, n_fft // 2 + 1, T, 2)

        rea = stft_spec[:, :, :, 0]  # real part, (B, n_fft // 2 + 1, T)
        imag = stft_spec[:, :, :, 1]  # imaginary part, (B, n_fft // 2 + 1, T)

        log_mag = torch.log(
            torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
        )  # (B, n_fft // 2 + 1, T)
        phase = torch.atan2(imag, rea)  # (B, n_fft // 2 + 1, T)

        return log_mag, phase


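def _stft_sketch():
    # Hedged usage sketch added for illustration; not part of the original
    # diff. Sizes are assumptions; with center=True, torch.stft produces
    # 1 + L // hop_length frames for an input of length L.
    stft = STFT(n_fft=1024, hop_length=256, win_length=1024)
    audio = torch.randn(2, 256 * 100)  # (B, T * hop_length) with T = 100
    log_mag, phase = stft(audio)
    assert log_mag.shape == (2, 513, 101)  # (B, n_fft // 2 + 1, T + 1)
    assert phase.shape == (2, 513, 101)

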
| 83 | 
         
            +
             
     | 
| 84 | 
         
            +
            class ISTFT(nn.Module):
         
     | 
| 85 | 
         
            +
                """
         
     | 
| 86 | 
         
            +
                Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
         
     | 
| 87 | 
         
            +
                windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
         
     | 
| 88 | 
         
            +
                See issue: https://github.com/pytorch/pytorch/issues/62323
         
     | 
| 89 | 
         
            +
                Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
         
     | 
| 90 | 
         
            +
                The NOLA constraint is met as we trim padded samples anyway.
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                Args:
         
     | 
| 93 | 
         
            +
                    n_fft (int): Size of Fourier transform.
         
     | 
| 94 | 
         
            +
                    hop_length (int): The distance between neighboring sliding window frames.
         
     | 
| 95 | 
         
            +
                    win_length (int): The size of window frame and STFT filter.
         
     | 
| 96 | 
         
            +
                    padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
         
     | 
| 97 | 
         
            +
                """
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                def __init__(
         
     | 
| 100 | 
         
            +
                    self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
         
     | 
| 101 | 
         
            +
                ):
         
     | 
| 102 | 
         
            +
                    super().__init__()
         
     | 
| 103 | 
         
            +
                    if padding not in ["center", "same"]:
         
     | 
| 104 | 
         
            +
                        raise ValueError("Padding must be 'center' or 'same'.")
         
     | 
| 105 | 
         
            +
                    self.padding = padding
         
     | 
| 106 | 
         
            +
                    self.n_fft = n_fft
         
     | 
| 107 | 
         
            +
                    self.hop_length = hop_length
         
     | 
| 108 | 
         
            +
                    self.win_length = win_length
         
     | 
| 109 | 
         
            +
                    window = torch.hann_window(win_length)
         
     | 
| 110 | 
         
            +
                    self.register_buffer("window", window)
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                def forward(self, spec: torch.Tensor) -> torch.Tensor:
         
     | 
| 113 | 
         
            +
                    """
         
     | 
| 114 | 
         
            +
                    Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
                    Args:
         
     | 
| 117 | 
         
            +
                        spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
         
     | 
| 118 | 
         
            +
                                        N is the number of frequency bins, and T is the number of time frames.
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                    Returns:
         
     | 
| 121 | 
         
            +
                        Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
         
     | 
| 122 | 
         
            +
                    """
         
     | 
| 123 | 
         
            +
                    if self.padding == "center":
         
     | 
| 124 | 
         
            +
                        # Fallback to pytorch native implementation
         
     | 
| 125 | 
         
            +
                        return torch.istft(
         
     | 
| 126 | 
         
            +
                            spec,
         
     | 
| 127 | 
         
            +
                            self.n_fft,
         
     | 
| 128 | 
         
            +
                            self.hop_length,
         
     | 
| 129 | 
         
            +
                            self.win_length,
         
     | 
| 130 | 
         
            +
                            self.window,
         
     | 
| 131 | 
         
            +
                            center=True,
         
     | 
| 132 | 
         
            +
                        )
         
     | 
| 133 | 
         
            +
                    elif self.padding == "same":
         
     | 
| 134 | 
         
            +
                        pad = (self.win_length - self.hop_length) // 2
         
     | 
| 135 | 
         
            +
                    else:
         
     | 
| 136 | 
         
            +
                        raise ValueError("Padding must be 'center' or 'same'.")
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
                    assert spec.dim() == 3, "Expected a 3D tensor as input"
         
     | 
| 139 | 
         
            +
                    B, N, T = spec.shape
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                    # Inverse FFT
         
     | 
| 142 | 
         
            +
                    ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
         
     | 
| 143 | 
         
            +
                    ifft = ifft * self.window[None, :, None]
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                    # Overlap and Add
         
     | 
| 146 | 
         
            +
                    output_size = (T - 1) * self.hop_length + self.win_length
         
     | 
| 147 | 
         
            +
                    y = torch.nn.functional.fold(
         
     | 
| 148 | 
         
            +
                        ifft,
         
     | 
| 149 | 
         
            +
                        output_size=(1, output_size),
         
     | 
| 150 | 
         
            +
                        kernel_size=(1, self.win_length),
         
     | 
| 151 | 
         
            +
                        stride=(1, self.hop_length),
         
     | 
| 152 | 
         
            +
                    )[:, 0, 0, pad:-pad]
         
     | 
| 153 | 
         
            +
             
     | 
| 154 | 
         
            +
                    # Window envelope
         
     | 
| 155 | 
         
            +
                    window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
         
     | 
| 156 | 
         
            +
                    window_envelope = torch.nn.functional.fold(
         
     | 
| 157 | 
         
            +
                        window_sq,
         
     | 
| 158 | 
         
            +
                        output_size=(1, output_size),
         
     | 
| 159 | 
         
            +
                        kernel_size=(1, self.win_length),
         
     | 
| 160 | 
         
            +
                        stride=(1, self.hop_length),
         
     | 
| 161 | 
         
            +
                    ).squeeze()[pad:-pad]
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                    # Normalize
         
     | 
| 164 | 
         
            +
                    assert (window_envelope > 1e-11).all()
         
     | 
| 165 | 
         
            +
                    y = y / window_envelope
         
     | 
| 166 | 
         
            +
             
     | 
| 167 | 
         
            +
                    return y
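
# Minimal usage sketch for the custom ISTFT above (illustrative values, not the
# repository's configuration). It exercises the "same" path: the analysis STFT
# pads the signal by (win_length - hop_length) // 2 on each side and uses
# center=False with the same Hann window, so overlap-add followed by the
# window-envelope normalization reconstructs the input where coverage is full.
def _istft_roundtrip_sketch():
    n_fft, hop_length, win_length = 1024, 256, 1024
    x = torch.randn(2, 16000)  # (B, T) toy batch
    window = torch.hann_window(win_length)
    pad = (win_length - hop_length) // 2
    x_padded = torch.nn.functional.pad(x, (pad, pad))
    spec = torch.stft(
        x_padded,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=False,
        return_complex=True,
    )  # (B, n_fft // 2 + 1, frames)
    istft = ISTFT(
        n_fft=n_fft, hop_length=hop_length, win_length=win_length, padding="same"
    )
    y = istft(spec)
    n = min(x.shape[-1], y.shape[-1])
    # Reconstruction matches the input up to numerical tolerance.
    assert torch.allclose(x[..., :n], y[..., :n], atol=1e-4)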
         

class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                and N is the number of frequency bins.
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(
                audio, (self.frame_len // 2, self.frame_len // 2)
            )
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(
                audio, (self.frame_len // 4, self.frame_len // 4)
            )
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        X = torch.fft.fft(
            x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
        )[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)
         

class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
        """
        B, L, N = X.shape
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(
            Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
        )
        y = (
            torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
            * np.sqrt(N)
            * np.sqrt(2)
        )
        result = y * self.window.expand(y.shape)
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        audio = audio[:, pad:-pad]
        return audio
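
# Minimal round-trip sketch for the MDCT/IMDCT pair above (illustrative values).
# MDCT followed by IMDCT should reconstruct the signal up to the trimmed
# padding, by the time-domain aliasing cancellation (TDAC) property of the MDCT
# with a sine window; shapes follow the docstrings: (B, T) audio in,
# (B, L, N) coefficients with N = frame_len // 2.
def _mdct_roundtrip_sketch():
    frame_len = 512
    audio = torch.randn(1, 8192)
    mdct = MDCT(frame_len=frame_len, padding="center")
    imdct = IMDCT(frame_len=frame_len, padding="center")
    coeffs = mdct(audio)  # (1, L, 256)
    recon = imdct(coeffs)  # (1, 8192)
    n = min(audio.shape[-1], recon.shape[-1])
    assert torch.allclose(audio[..., :n], recon[..., :n], atol=1e-3)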
         

class FourierHead(nn.Module):
    """Base class for inverse Fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")
         

class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
                          the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # Phase wrapping happens here: these two lines produce the real and
        # imaginary parts directly.
        x = torch.cos(p)
        y = torch.sin(p)
        # Recomputing the phase (phase = torch.atan2(y, x) followed by
        # S = mag * torch.exp(phase * 1j)) produces nothing new and only costs
        # time, so we build the complex value directly.
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio
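
# Minimal usage sketch for ISTFTHead (illustrative dimensions, not the
# repository's configuration): the head maps backbone features (B, L, H) to
# audio (B, T); with padding="same" the output length is L * hop_length.
def _istft_head_sketch():
    head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
    feats = torch.randn(2, 100, 512)  # (B, L, H) from the backbone
    audio = head(feats)
    assert audio.shape == (2, 100 * 256)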
         

class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with a symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
                                     based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally initialize the last layer following the mel scale
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(
            x, min=-1e2, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio
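
# Reference sketch: symexp used above is assumed to be the usual inverse of
# symlog, i.e. symexp(x) = sign(x) * (exp(|x|) - 1), which lets the head predict
# MDCT coefficients of either sign while keeping their dynamic range compressed.
def _symexp_reference(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)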
         

class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p)

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        clip_audio: bool = False,
    ):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(
            max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)
        return audio
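
# Minimal usage sketch for IMDCTCosHead (illustrative dimensions): the
# projection splits into a log-magnitude m and a phase-like p, so the predicted
# coefficients exp(m) * cos(p) have magnitude bounded by the clipped exp(m);
# with padding="same" the output length is L * (mdct_frame_len // 2).
def _imdct_cos_head_sketch():
    head = IMDCTCosHead(dim=512, mdct_frame_len=512, padding="same")
    feats = torch.randn(2, 100, 512)  # (B, L, H)
    audio = head(feats)
    assert audio.shape == (2, 100 * 256)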
         

class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to a 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float): Initial value for the layer scale.
            Values <= 0 disable scaling.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
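
# Minimal usage sketch for ConvNeXtBlock (illustrative dimensions): the block is
# shape-preserving over (B, C, T). With adanorm_num_embeddings set, a shared
# condition id of shape (1,) broadcasts over batch and time and selects the
# AdaLayerNorm scale/shift embeddings.
def _convnext_block_sketch():
    block = ConvNeXtBlock(
        dim=256,
        intermediate_dim=768,
        layer_scale_init_value=1e-6,
        adanorm_num_embeddings=4,
    )
    x = torch.randn(2, 256, 100)  # (B, C, T)
    cond_id = torch.tensor([1])  # shared condition id
    y = block(x, cond_embedding_id=cond_id)
    assert y.shape == x.shape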
         

class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        self.shift = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x
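
# Minimal usage sketch for AdaLayerNorm (illustrative dimensions): scale weights
# are initialized to ones and shift weights to zeros, so the module starts out
# identical to a plain (affine-free) LayerNorm and the conditional modulation is
# learned during training.
def _ada_layer_norm_sketch():
    norm = AdaLayerNorm(num_embeddings=4, embedding_dim=64)
    x = torch.randn(2, 50, 64)  # (B, T, C)
    y = norm(x, torch.tensor([1]))  # embedding id broadcasts over batch/time
    assert y.shape == x.shape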
         

class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        # Dilated convolutions, one per dilation factor
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=self.get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )

        # Non-dilated convolutions paired with the dilated ones above
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )

        self.gamma = nn.ParameterList(
            [
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                )
                for _ in dilation
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)
         
     | 
| 702 | 
         
            +
             
     | 
| 703 | 
         
            +
             
     | 
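
The `get_padding` helper above is what keeps each dilated convolution in ResBlock1 length-preserving, so the residual addition `x = xt + x` is well-defined. A standalone sanity check of that arithmetic (a sketch, not part of this diff; the kernel size and dilation are illustrative):

import torch
import torch.nn as nn

# With pad = (k*d - d) // 2 and an odd kernel size k, a dilated Conv1d
# preserves sequence length: L_out = L + 2*pad - d*(k - 1) = L.
k, d = 7, 3
conv = nn.Conv1d(4, 4, kernel_size=k, dilation=d, padding=(k * d - d) // 2)
x = torch.randn(1, 4, 50)
assert conv(x).shape == x.shape
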
+class Backbone(nn.Module):
+    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Args:
+            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+                        C denotes input feature channels, and L is the sequence length.
+
+        Returns:
+            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+                    and H denotes the model dimension.
+        """
+        raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class VocosBackbone(Backbone):
+    """
+    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization.
+
+    Args:
+        input_channels (int): Number of input feature channels.
+        dim (int): Hidden dimension of the model.
+        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+        num_layers (int): Number of ConvNeXtBlock layers.
+        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+                                                None means non-conditional model. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        dim: int,
+        intermediate_dim: int,
+        num_layers: int,
+        layer_scale_init_value: Optional[float] = None,
+        adanorm_num_embeddings: Optional[int] = None,
+    ):
+        super().__init__()
+        self.input_channels = input_channels
+        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+        self.adanorm = adanorm_num_embeddings is not None
+        if adanorm_num_embeddings:
+            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+        else:
+            self.norm = nn.LayerNorm(dim, eps=1e-6)
+        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+        self.convnext = nn.ModuleList(
+            [
+                ConvNeXtBlock(
+                    dim=dim,
+                    intermediate_dim=intermediate_dim,
+                    layer_scale_init_value=layer_scale_init_value,
+                    adanorm_num_embeddings=adanorm_num_embeddings,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        bandwidth_id = kwargs.get("bandwidth_id", None)
+        x = self.embed(x)
+        if self.adanorm:
+            assert bandwidth_id is not None
+            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+        else:
+            x = self.norm(x.transpose(1, 2))
+        x = x.transpose(1, 2)
+        for conv_block in self.convnext:
+            x = conv_block(x, cond_embedding_id=bandwidth_id)
+        x = self.final_layer_norm(x.transpose(1, 2))
+        return x
+
+
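
Per the `Backbone` contract above, `VocosBackbone` maps frame-level features of shape (B, C, L) to (B, L, H). A quick shape check (a sketch assuming `ConvNeXtBlock` and `AdaLayerNorm` from earlier in this file are in scope; the sizes are illustrative):

import torch

backbone = VocosBackbone(
    input_channels=100, dim=384, intermediate_dim=1152, num_layers=8
)
feats = torch.randn(2, 100, 250)   # (B, C, L)
out = backbone(feats)
assert out.shape == (2, 250, 384)  # (B, L, dim)
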
         
            +
            class VocosResNetBackbone(Backbone):
         
     | 
| 787 | 
         
            +
                """
         
     | 
| 788 | 
         
            +
                Vocos backbone module built with ResBlocks.
         
     | 
| 789 | 
         
            +
             
     | 
| 790 | 
         
            +
                Args:
         
     | 
| 791 | 
         
            +
                    input_channels (int): Number of input features channels.
         
     | 
| 792 | 
         
            +
                    dim (int): Hidden dimension of the model.
         
     | 
| 793 | 
         
            +
                    num_blocks (int): Number of ResBlock1 blocks.
         
     | 
| 794 | 
         
            +
                    layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
         
     | 
| 795 | 
         
            +
                """
         
     | 
| 796 | 
         
            +
             
     | 
| 797 | 
         
            +
                def __init__(
         
     | 
| 798 | 
         
            +
                    self,
         
     | 
| 799 | 
         
            +
                    input_channels,
         
     | 
| 800 | 
         
            +
                    dim,
         
     | 
| 801 | 
         
            +
                    num_blocks,
         
     | 
| 802 | 
         
            +
                    layer_scale_init_value=None,
         
     | 
| 803 | 
         
            +
                ):
         
     | 
| 804 | 
         
            +
                    super().__init__()
         
     | 
| 805 | 
         
            +
                    self.input_channels = input_channels
         
     | 
| 806 | 
         
            +
                    self.embed = weight_norm(
         
     | 
| 807 | 
         
            +
                        nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
         
     | 
| 808 | 
         
            +
                    )
         
     | 
| 809 | 
         
            +
                    layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
         
     | 
| 810 | 
         
            +
                    self.resnet = nn.Sequential(
         
     | 
| 811 | 
         
            +
                        *[
         
     | 
| 812 | 
         
            +
                            ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
         
     | 
| 813 | 
         
            +
                            for _ in range(num_blocks)
         
     | 
| 814 | 
         
            +
                        ]
         
     | 
| 815 | 
         
            +
                    )
         
     | 
| 816 | 
         
            +
             
     | 
| 817 | 
         
            +
                def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         
     | 
| 818 | 
         
            +
                    x = self.embed(x)
         
     | 
| 819 | 
         
            +
                    x = self.resnet(x)
         
     | 
| 820 | 
         
            +
                    x = x.transpose(1, 2)
         
     | 
| 821 | 
         
            +
                    return x
         
     | 
| 822 | 
         
            +
             
     | 
| 823 | 
         
            +
             
     | 
| 824 | 
         
            +
            class Vocos(nn.Module):
         
     | 
| 825 | 
         
            +
                def __init__(
         
     | 
| 826 | 
         
            +
                    self,
         
     | 
| 827 | 
         
            +
                    input_channels: int = 256,
         
     | 
| 828 | 
         
            +
                    dim: int = 384,
         
     | 
| 829 | 
         
            +
                    intermediate_dim: int = 1152,
         
     | 
| 830 | 
         
            +
                    num_layers: int = 8,
         
     | 
| 831 | 
         
            +
                    n_fft: int = 800,
         
     | 
| 832 | 
         
            +
                    hop_size: int = 200,
         
     | 
| 833 | 
         
            +
                    padding: str = "same",
         
     | 
| 834 | 
         
            +
                    adanorm_num_embeddings=None,
         
     | 
| 835 | 
         
            +
                    cfg=None,
         
     | 
| 836 | 
         
            +
                ):
         
     | 
| 837 | 
         
            +
                    super().__init__()
         
     | 
| 838 | 
         
            +
             
     | 
| 839 | 
         
            +
                    input_channels = (
         
     | 
| 840 | 
         
            +
                        cfg.input_channels
         
     | 
| 841 | 
         
            +
                        if cfg is not None and hasattr(cfg, "input_channels")
         
     | 
| 842 | 
         
            +
                        else input_channels
         
     | 
| 843 | 
         
            +
                    )
         
     | 
| 844 | 
         
            +
                    dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim
         
     | 
| 845 | 
         
            +
                    intermediate_dim = (
         
     | 
| 846 | 
         
            +
                        cfg.intermediate_dim
         
     | 
| 847 | 
         
            +
                        if cfg is not None and hasattr(cfg, "intermediate_dim")
         
     | 
| 848 | 
         
            +
                        else intermediate_dim
         
     | 
| 849 | 
         
            +
                    )
         
     | 
| 850 | 
         
            +
                    num_layers = (
         
     | 
| 851 | 
         
            +
                        cfg.num_layers
         
     | 
| 852 | 
         
            +
                        if cfg is not None and hasattr(cfg, "num_layers")
         
     | 
| 853 | 
         
            +
                        else num_layers
         
     | 
| 854 | 
         
            +
                    )
         
     | 
| 855 | 
         
            +
                    adanorm_num_embeddings = (
         
     | 
| 856 | 
         
            +
                        cfg.adanorm_num_embeddings
         
     | 
| 857 | 
         
            +
                        if cfg is not None and hasattr(cfg, "adanorm_num_embeddings")
         
     | 
| 858 | 
         
            +
                        else adanorm_num_embeddings
         
     | 
| 859 | 
         
            +
                    )
         
     | 
| 860 | 
         
            +
                    n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft
         
     | 
| 861 | 
         
            +
                    hop_size = (
         
     | 
| 862 | 
         
            +
                        cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size
         
     | 
| 863 | 
         
            +
                    )
         
     | 
| 864 | 
         
            +
                    padding = (
         
     | 
| 865 | 
         
            +
                        cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding
         
     | 
| 866 | 
         
            +
                    )
         
     | 
| 867 | 
         
            +
             
     | 
| 868 | 
         
            +
                    self.backbone = VocosBackbone(
         
     | 
| 869 | 
         
            +
                        input_channels=input_channels,
         
     | 
| 870 | 
         
            +
                        dim=dim,
         
     | 
| 871 | 
         
            +
                        intermediate_dim=intermediate_dim,
         
     | 
| 872 | 
         
            +
                        num_layers=num_layers,
         
     | 
| 873 | 
         
            +
                        adanorm_num_embeddings=adanorm_num_embeddings,
         
     | 
| 874 | 
         
            +
                    )
         
     | 
| 875 | 
         
            +
                    self.head = ISTFTHead(dim, n_fft, hop_size, padding)
         
     | 
| 876 | 
         
            +
             
     | 
| 877 | 
         
            +
                def forward(self, x):
         
     | 
| 878 | 
         
            +
                    x = self.backbone(x)
         
     | 
| 879 | 
         
            +
                    x = self.head(x)
         
     | 
| 880 | 
         
            +
             
     | 
| 881 | 
         
            +
                    return x[:, None, :]
         
     | 
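
The `Vocos` wrapper above chains the ConvNeXt backbone with the `ISTFTHead` defined earlier in this file and returns the waveform with a channel axis. A minimal smoke test under the defaults shown (with padding="same", the output length should be roughly L * hop_size samples):

import torch

model = Vocos()                   # defaults: 256 feature channels, hop_size=200
feats = torch.randn(2, 256, 100)  # (B, C, L)
with torch.no_grad():
    audio = model(feats)          # (B, 1, T), T ~= 100 * 200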
    	
models/codec/codec_dataset.py
ADDED
@@ -0,0 +1,264 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable
+import torch
+import numpy as np
+import torch.utils.data
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from torch.utils.data import ConcatDataset, Dataset
+
+
+class CodecDataset(torch.utils.data.Dataset):
+    def __init__(self, cfg, dataset, is_valid=False):
+        """
+        Args:
+            cfg: config
+            dataset: dataset name
+            is_valid: whether to use train or valid dataset
+        """
+        assert isinstance(dataset, str)
+
+        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+
+        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+        self.metafile_path = os.path.join(processed_data_dir, meta_file)
+        self.metadata = self.get_metadata()
+
+        self.data_root = processed_data_dir
+        self.cfg = cfg
+
+        if cfg.preprocess.use_audio:
+            self.utt2audio_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2audio_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.audio_dir,
+                    uid + ".npy",
+                )
+        elif cfg.preprocess.use_label:
+            self.utt2label_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2label_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.label_dir,
+                    uid + ".npy",
+                )
+        elif cfg.preprocess.use_one_hot:
+            self.utt2one_hot_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2one_hot_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.one_hot_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_mel:
+            self.utt2mel_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2mel_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.mel_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_frame_pitch:
+            self.utt2frame_pitch_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2frame_pitch_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.pitch_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_uv:
+            self.utt2uv_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+                self.utt2uv_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.uv_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_amplitude_phase:
+            self.utt2logamp_path = {}
+            self.utt2pha_path = {}
+            self.utt2rea_path = {}
+            self.utt2imag_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+                self.utt2logamp_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.log_amplitude_dir,
+                    uid + ".npy",
+                )
+                self.utt2pha_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.phase_dir,
+                    uid + ".npy",
+                )
+                self.utt2rea_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.real_dir,
+                    uid + ".npy",
+                )
+                self.utt2imag_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.imaginary_dir,
+                    uid + ".npy",
+                )
+
+    def __getitem__(self, index):
+        utt_info = self.metadata[index]
+
+        dataset = utt_info["Dataset"]
+        uid = utt_info["Uid"]
+        utt = "{}_{}".format(dataset, uid)
+
+        single_feature = dict()
+
+        if self.cfg.preprocess.use_mel:
+            mel = np.load(self.utt2mel_path[utt])
+            assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = mel.shape[1]
+
+            single_feature["mel"] = mel
+
+        if self.cfg.preprocess.use_frame_pitch:
+            frame_pitch = np.load(self.utt2frame_pitch_path[utt])
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = len(frame_pitch)
+
+            aligned_frame_pitch = align_length(
+                frame_pitch, single_feature["target_len"]
+            )
+
+            single_feature["frame_pitch"] = aligned_frame_pitch
+
+        if self.cfg.preprocess.use_audio:
+            audio = np.load(self.utt2audio_path[utt])
+
+            single_feature["audio"] = audio
+
+        return single_feature
+
+    def get_metadata(self):
+        with open(self.metafile_path, "r", encoding="utf-8") as f:
+            metadata = json.load(f)
+
+        return metadata
+
+    def get_dataset_name(self):
+        return self.metadata[0]["Dataset"]
+
+    def __len__(self):
+        return len(self.metadata)
+
+
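
`CodecDataset.get_metadata` above expects the metafile to be a JSON list whose entries carry at least the "Dataset" and "Uid" keys read in `__init__` and `__getitem__`. A hypothetical minimal metafile (the dataset name and uids are illustrative; real Amphion metafiles may carry more fields):

[
    {"Dataset": "libritts", "Uid": "100_121669_000001_000000"},
    {"Dataset": "libritts", "Uid": "100_121669_000002_000000"}
]
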
+class CodecConcatDataset(ConcatDataset):
+    def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
+        """Concatenate a series of datasets with their random inference audio merged."""
+        super().__init__(datasets)
+
+        self.cfg = self.datasets[0].cfg
+
+        self.metadata = []
+
+        # Merge metadata
+        for dataset in self.datasets:
+            self.metadata += dataset.metadata
+
+        # Merge random inference features
+        if full_audio_inference:
+            self.eval_audios = []
+            self.eval_dataset_names = []
+            if self.cfg.preprocess.use_mel:
+                self.eval_mels = []
+            if self.cfg.preprocess.use_frame_pitch:
+                self.eval_pitchs = []
+            for dataset in self.datasets:
+                self.eval_audios.append(dataset.eval_audio)
+                self.eval_dataset_names.append(dataset.get_dataset_name())
+                if self.cfg.preprocess.use_mel:
+                    self.eval_mels.append(dataset.eval_mel)
+                if self.cfg.preprocess.use_frame_pitch:
+                    self.eval_pitchs.append(dataset.eval_pitch)
+
+
+class CodecCollator(object):
+    """Zero-pads model inputs and targets based on number of frames per step"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def __call__(self, batch):
+        packed_batch_features = dict()
+
+        # mel: [b, n_mels, frame]
+        # frame_pitch: [b, frame]
+        # audios: [b, frame * hop_size]
+
+        for key in batch[0].keys():
+            if key == "target_len":
+                packed_batch_features["target_len"] = torch.LongTensor(
+                    [b["target_len"] for b in batch]
+                )
+                masks = [
+                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+                ]
+                packed_batch_features["mask"] = pad_sequence(
+                    masks, batch_first=True, padding_value=0
+                )
+            elif key == "mel":
+                values = [torch.from_numpy(b[key]).T for b in batch]
+                packed_batch_features[key] = pad_sequence(
+                    values, batch_first=True, padding_value=0
+                )
+            else:
+                values = [torch.from_numpy(b[key]) for b in batch]
+                packed_batch_features[key] = pad_sequence(
+                    values, batch_first=True, padding_value=0
+                )
+
+        return packed_batch_features
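
A minimal sketch of how these pieces compose into a training loader (assuming a `cfg` object carrying the `preprocess` fields referenced above; the dataset name and batch size are illustrative):

from torch.utils.data import DataLoader

train_set = CodecDataset(cfg, dataset="libritts", is_valid=False)
loader = DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    collate_fn=CodecCollator(cfg),
)
batch = next(iter(loader))
# batch["mel"]:  (B, T_max, n_mels) -- mels are transposed before pad_sequence
# batch["mask"]: (B, T_max, 1) long mask built from each item's "target_len"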
    	
models/codec/codec_inference.py
ADDED
@@ -0,0 +1,515 @@
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
models/codec/codec_inference.py
ADDED
@@ -0,0 +1,515 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import json
import json5
import time
import accelerate
import random
import numpy as np
import shutil

from pathlib import Path
from tqdm import tqdm
from glob import glob
from accelerate.logging import get_logger
from torch.utils.data import DataLoader

from models.vocoders.vocoder_dataset import (
    VocoderDataset,
    VocoderCollator,
    VocoderConcatDataset,
)

from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
from models.vocoders.flow.waveglow import waveglow
from models.vocoders.diffusion.diffwave import diffwave
from models.vocoders.autoregressive.wavenet import wavenet
from models.vocoders.autoregressive.wavernn import wavernn

from models.vocoders.gan import gan_vocoder_inference
from models.vocoders.diffusion import diffusion_vocoder_inference

from utils.io import save_audio

_vocoders = {
    "diffwave": diffwave.DiffWave,
    "wavernn": wavernn.WaveRNN,
    "wavenet": wavenet.WaveNet,
    "waveglow": waveglow.WaveGlow,
    "nsfhifigan": nsfhifigan.NSFHiFiGAN,
    "bigvgan": bigvgan.BigVGAN,
    "hifigan": hifigan.HiFiGAN,
    "melgan": melgan.MelGAN,
    "apnet": apnet.APNet,
}

# Forward call for the generalized inferencer
_vocoder_forward_funcs = {
    # "world": world_inference.synthesis_audios,
    # "wavernn": wavernn_inference.synthesis_audios,
    # "wavenet": wavenet_inference.synthesis_audios,
    "diffwave": diffusion_vocoder_inference.vocoder_inference,
    "nsfhifigan": gan_vocoder_inference.vocoder_inference,
    "bigvgan": gan_vocoder_inference.vocoder_inference,
    "melgan": gan_vocoder_inference.vocoder_inference,
    "hifigan": gan_vocoder_inference.vocoder_inference,
    "apnet": gan_vocoder_inference.vocoder_inference,
}

# APIs for other tasks, e.g. SVC, TTS, TTA...
_vocoder_infer_funcs = {
    # "world": world_inference.synthesis_audios,
    # "wavernn": wavernn_inference.synthesis_audios,
    # "wavenet": wavenet_inference.synthesis_audios,
    "diffwave": diffusion_vocoder_inference.synthesis_audios,
    "nsfhifigan": gan_vocoder_inference.synthesis_audios,
    "bigvgan": gan_vocoder_inference.synthesis_audios,
    "melgan": gan_vocoder_inference.synthesis_audios,
    "hifigan": gan_vocoder_inference.synthesis_audios,
    "apnet": gan_vocoder_inference.synthesis_audios,
}
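
# Illustrative dispatch through the three registries above (a sketch, assuming a
# hypothetical cfg with cfg.model.generator == "hifigan"; the call signatures
# follow their uses later in this file):
#
#   model = _vocoders["hifigan"](cfg)                    # build the generator
#   wav = _vocoder_forward_funcs["hifigan"](             # batched forward call
#       cfg, model, mels, device=next(model.parameters()).device
#   )
#   wavs = _vocoder_infer_funcs["hifigan"](              # list-in, list-out API
#       cfg, model, mel_list, f0s=None, batch_size=64, fast_inference=False
#   )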


class VocoderInference(object):
    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        super().__init__()

        start = time.monotonic_ns()
        self.args = args
        self.cfg = cfg
        self.infer_type = infer_type

        # Init accelerator
        self.accelerator = accelerate.Accelerator()
        self.accelerator.wait_for_everyone()

        # Get logger
        with self.accelerator.main_process_first():
            self.logger = get_logger("inference", log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")

        self.vocoder_dir = args.vocoder_dir
        self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")

        os.makedirs(args.output_dir, exist_ok=True)
        if os.path.exists(os.path.join(args.output_dir, "pred")):
            shutil.rmtree(os.path.join(args.output_dir, "pred"))
        if os.path.exists(os.path.join(args.output_dir, "gt")):
            shutil.rmtree(os.path.join(args.output_dir, "gt"))
        os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True)

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Setup inference mode
        if self.infer_type == "infer_from_dataset":
            self.cfg.dataset = self.args.infer_datasets
        elif self.infer_type == "infer_from_feature":
            self._build_tmp_dataset_from_feature()
            self.cfg.dataset = ["tmp"]
        elif self.infer_type == "infer_from_audio":
            self._build_tmp_dataset_from_audio()
            self.cfg.dataset = ["tmp"]

        # Setup data loader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.test_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")

        # Init with accelerate
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self.accelerator = accelerate.Accelerator()
        (self.model, self.test_dataloader) = self.accelerator.prepare(
            self.model, self.test_dataloader
        )
        end = time.monotonic_ns()
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")

        with self.accelerator.main_process_first():
            self.logger.info("Loading checkpoint...")
            start = time.monotonic_ns()
            if os.path.isdir(args.vocoder_dir):
                if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")):
                    self._load_model(os.path.join(args.vocoder_dir, "checkpoint"))
                else:
                    self._load_model(os.path.join(args.vocoder_dir))
            else:
                self._load_model(os.path.join(args.vocoder_dir))
            end = time.monotonic_ns()
            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")

        self.model.eval()
        self.accelerator.wait_for_everyone()
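
# Minimal construction sketch (hypothetical argparse.Namespace; the field names
# mirror the attributes read in __init__ above, and cfg is a loaded Amphion config):
#
#   import argparse
#   args = argparse.Namespace(
#       vocoder_dir="ckpts/vocoder", output_dir="results",
#       log_level="info", infer_datasets=["ljspeech"],
#   )
#   inferencer = VocoderInference(args, cfg, infer_type="infer_from_dataset")
#   inferencer.inference()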

    def _build_tmp_dataset_from_feature(self):
        if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
            shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))

        utts = []
        mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
        for i, mel in enumerate(mels):
            uid = mel.split("/")[-1].split(".")[0]
            utt = {"Dataset": "tmp", "Uid": uid, "index": i}
            utts.append(utt)

        os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
        with open(
            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
        ) as f:
            json.dump(utts, f)

        meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}

        with open(
            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
            "w",
        ) as f:
            json.dump(meta_info, f)

        features = glob(os.path.join(self.args.feature_folder, "*"))
        for feature in features:
            feature_name = feature.split("/")[-1]
            if os.path.isfile(feature):
                continue
            shutil.copytree(
                os.path.join(self.args.feature_folder, feature_name),
                os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name),
            )

    def _build_tmp_dataset_from_audio(self):
        if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
            shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))

        utts = []
        audios = glob(os.path.join(self.args.audio_folder, "*"))
        for i, audio in enumerate(audios):
            uid = audio.split("/")[-1].split(".")[0]
            utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio}
            utts.append(utt)

        os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
        with open(
            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
        ) as f:
            json.dump(utts, f)

        meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}

        with open(
            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
            "w",
        ) as f:
            json.dump(meta_info, f)

        from processors import acoustic_extractor

        acoustic_extractor.extract_utt_acoustic_features_serial(
            utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg
        )
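
# Both helpers above materialize the same temporary dataset layout under
# <processed_dir>/tmp, summarized here for reference:
#
#   test.json       -> [{"Dataset": "tmp", "Uid": "...", "index": 0, ...}, ...]
#   meta_info.json  -> {"dataset": "tmp", "test": {"size": <num_utts>}}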

    def _build_test_dataset(self):
        return VocoderDataset, VocoderCollator

    def _build_model(self):
        model = _vocoders[self.cfg.model.generator](self.cfg)
        return model

    def _build_dataloader(self):
        """Build a dataloader that merges a series of datasets."""
        Dataset, Collator = self._build_test_dataset()

        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=True)
            datasets_list.append(subdataset)
        test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
        test_collate = Collator(self.cfg)
        test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset))
        test_dataloader = DataLoader(
            test_dataset,
            collate_fn=test_collate,
            num_workers=1,
            batch_size=test_batch_size,
            shuffle=False,
        )
        self.test_batch_size = test_batch_size
        self.test_dataset = test_dataset
        return test_dataloader

    def _load_model(self, checkpoint_dir, from_multi_gpu=False):
        """Load model from checkpoint. If a folder is given, it will
        load the latest checkpoint in checkpoint_dir. If a path is given,
        it will load the checkpoint specified by checkpoint_path.
        **Only use this method after** ``accelerator.prepare()``.
        """
        if os.path.isdir(checkpoint_dir):
            if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
                checkpoint_path = checkpoint_dir
            else:
                # Load the latest accelerator state dicts
                ls = [
                    str(i)
                    for i in Path(checkpoint_dir).glob("*")
                    if "audio" not in str(i)
                ]
                ls.sort(
                    key=lambda x: int(x.split("/")[-1].split("_")[0].split("-")[-1]),
                    reverse=True,
                )
                checkpoint_path = ls[0]
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            return str(checkpoint_path)
        else:
            # Load old .pt checkpoints
            if self.cfg.model.generator in [
                "bigvgan",
                "hifigan",
                "melgan",
                "nsfhifigan",
            ]:
                ckpt = torch.load(
                    checkpoint_dir,
                    map_location=(
                        torch.device("cuda")
                        if torch.cuda.is_available()
                        else torch.device("cpu")
                    ),
                )
                if from_multi_gpu:
                    pretrained_generator_dict = ckpt["generator_state_dict"]
                    generator_dict = self.model.state_dict()

                    new_generator_dict = {
                        k.split("module.")[-1]: v
                        for k, v in pretrained_generator_dict.items()
                        if (
                            k.split("module.")[-1] in generator_dict
                            and v.shape == generator_dict[k.split("module.")[-1]].shape
                        )
                    }

                    generator_dict.update(new_generator_dict)

                    self.model.load_state_dict(generator_dict)
                else:
                    self.model.load_state_dict(ckpt["generator_state_dict"])
            else:
                self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"])
            return str(checkpoint_dir)
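
# Note: the sort key in _load_model above parses the epoch number out of the
# checkpoint folder's basename, which is therefore assumed to be named like
# "epoch-0004_step-100000_..." (inferred from the parsing, not documented here);
# the folders are ordered newest-first and ls[0] is loaded.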

    def inference(self):
        """Inference via batches"""
        for i, batch in tqdm(enumerate(self.test_dataloader)):
            if self.cfg.preprocess.use_frame_pitch:
                audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
                    self.cfg,
                    self.model,
                    batch["mel"].transpose(-1, -2),
                    f0s=batch["frame_pitch"].float(),
                    device=next(self.model.parameters()).device,
                )
            else:
                audio_pred = _vocoder_forward_funcs[self.cfg.model.generator](
                    self.cfg,
                    self.model,
                    batch["mel"].transpose(-1, -2),
                    device=next(self.model.parameters()).device,
                )
            audio_ls = audio_pred.chunk(self.test_batch_size)
            audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
            length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
            j = 0
            for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls):
                l = l.item()
                it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size]
                it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size]
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                save_audio(
                    os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid),
                    it,
                    self.cfg.preprocess.sample_rate,
                )
                save_audio(
                    os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid),
                    it_gt,
                    self.cfg.preprocess.sample_rate,
                )
                j += 1

        if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
            shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
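
# Note on the trimming in inference() above: `target_len` counts mel frames, so
# multiplying by `hop_size` converts frames into waveform samples, e.g.
# 200 frames * 256 hop -> 51200 samples at the configured sample rate.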

    def _set_random_seed(self, seed):
        """Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    def _count_parameters(self, model):
        return sum(p.numel() for p in model.parameters())

    def _dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        json5.dump(
            self.cfg,
            open(path, "w"),
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
            quote_keys=True,
        )


def load_nnvocoder(
    cfg,
    vocoder_name,
    weights_file,
    from_multi_gpu=False,
):
    """Load the specified vocoder.
    cfg: the vocoder config file.
    weights_file: a folder or a .pt path.
    from_multi_gpu: automatically remove the "module." prefix in state dicts if True.
    """
    print("Loading Vocoder from Weights file: {}".format(weights_file))

    # Build model
    model = _vocoders[vocoder_name](cfg)
    if not os.path.isdir(weights_file):
        # Load from .pt file
        if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
            ckpt = torch.load(
                weights_file,
                map_location=(
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                ),
            )
            if from_multi_gpu:
                pretrained_generator_dict = ckpt["generator_state_dict"]
                generator_dict = model.state_dict()

                new_generator_dict = {
                    k.split("module.")[-1]: v
                    for k, v in pretrained_generator_dict.items()
                    if (
                        k.split("module.")[-1] in generator_dict
                        and v.shape == generator_dict[k.split("module.")[-1]].shape
                    )
                }

                generator_dict.update(new_generator_dict)

                model.load_state_dict(generator_dict)
            else:
                model.load_state_dict(ckpt["generator_state_dict"])
        else:
            model.load_state_dict(torch.load(weights_file)["state_dict"])
    else:
        # Load from accelerator state dict
        weights_file = os.path.join(weights_file, "checkpoint")
        ls = [str(i) for i in Path(weights_file).glob("*") if "audio" not in str(i)]
        ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
        checkpoint_path = ls[0]
        accelerator = accelerate.Accelerator()
        model = accelerator.prepare(model)
        accelerator.load_state(checkpoint_path)

    if torch.cuda.is_available():
        model = model.cuda()

    model = model.eval()
    return model


def tensorize(data, device, n_samples):
    """
    data: a list of numpy arrays
    """
    assert isinstance(data, list)
    if n_samples:
        data = data[:n_samples]
    data = [torch.as_tensor(x, device=device) for x in data]
    return data
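
# Example (illustrative): moving two variable-length mel arrays onto a device.
#
#   import numpy as np
#   mels = [np.zeros((80, 120), dtype=np.float32),
#           np.zeros((80, 96), dtype=np.float32)]
#   tensors = tensorize(mels, torch.device("cpu"), n_samples=None)  # list of Tensors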


def synthesis(
    cfg,
    vocoder_weight_file,
    n_samples,
    pred,
    f0s=None,
    batch_size=64,
    fast_inference=False,
):
    """Synthesize audios from a given vocoder and a series of given features.
    cfg: vocoder config.
    vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file.
    pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]
    """

    vocoder_name = cfg.model.generator

    print("Synthesizing audios using {} vocoder...".format(vocoder_name))

    ###### TODO: World Vocoder Refactor ######
    # if vocoder_name == "world":
    #     world_inference.synthesis_audios(
    #         cfg, dataset_name, split, n_samples, pred, save_dir, tag
    #     )
    #     return

    # ====== Loading neural vocoder model ======
    vocoder = load_nnvocoder(
        cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
    )
    device = next(vocoder.parameters()).device

    # ====== Inference for predicted acoustic features ======
    # pred: (frame_len, n_mels) -> (n_mels, frame_len)
    mels_pred = tensorize([p.T for p in pred], device, n_samples)
    print("For predicted mels, #sample = {}...".format(len(mels_pred)))
    audios_pred = _vocoder_infer_funcs[vocoder_name](
        cfg,
        vocoder,
        mels_pred,
        f0s=f0s,
        batch_size=batch_size,
        fast_inference=fast_inference,
    )
    return audios_pred
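
# End-to-end sketch (hypothetical checkpoint path; the mel layout follows the
# synthesis() docstring above, i.e. a list of (seq_len, n_mels) arrays):
#
#   import numpy as np
#   pred = [np.random.randn(200, 80).astype(np.float32),
#           np.random.randn(160, 80).astype(np.float32)]
#   wavs = synthesis(cfg, "ckpts/vocoder", n_samples=None, pred=pred, batch_size=2)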
    	
models/codec/codec_sampler.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import random

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
    BatchSampler,
    RandomSampler,
    Sampler,
    SequentialSampler,
)


class ScheduledSampler(Sampler):
    """A sampler that samples data from a given concat-dataset.

    Args:
        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
        batch_size (int): batch size
        holistic_shuffle (bool): whether to shuffle the whole dataset or not
        logger (logging.Logger): logger to print warning messages

    Usage:
        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
        [3, 4, 5, 0, 1, 2, 6, 7, 8]
    """

    def __init__(
        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
    ):
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    type(concat_dataset)
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(type(batch_size))
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    type(holistic_shuffle)
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle

        affected_dataset_name = []
        affected_dataset_len = []
        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            dataset_name = dataset.get_dataset_name()
            if dataset_len < batch_size:
                affected_dataset_name.append(dataset_name)
                affected_dataset_len.append(dataset_len)

        self.type = type
        for dataset_name, dataset_len in zip(
            affected_dataset_name, affected_dataset_len
        ):
            if type != "valid":
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
         
            +
                                    type, dataset_name, dataset_len, batch_size
         
     | 
| 74 | 
         
            +
                                )
         
     | 
| 75 | 
         
            +
                            )
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                def __len__(self):
         
     | 
| 78 | 
         
            +
                    # the number of batches with drop last
         
     | 
| 79 | 
         
            +
                    num_of_batches = sum(
         
     | 
| 80 | 
         
            +
                        [
         
     | 
| 81 | 
         
            +
                            math.floor(len(dataset) / self.batch_size)
         
     | 
| 82 | 
         
            +
                            for dataset in self.concat_dataset.datasets
         
     | 
| 83 | 
         
            +
                        ]
         
     | 
| 84 | 
         
            +
                    )
         
     | 
| 85 | 
         
            +
                    return num_of_batches * self.batch_size
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                def __iter__(self):
         
     | 
| 88 | 
         
            +
                    iters = []
         
     | 
| 89 | 
         
            +
                    for dataset in self.concat_dataset.datasets:
         
     | 
| 90 | 
         
            +
                        iters.append(
         
     | 
| 91 | 
         
            +
                            SequentialSampler(dataset).__iter__()
         
     | 
| 92 | 
         
            +
                            if self.holistic_shuffle
         
     | 
| 93 | 
         
            +
                            else RandomSampler(dataset).__iter__()
         
     | 
| 94 | 
         
            +
                        )
         
     | 
| 95 | 
         
            +
                    init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
         
     | 
| 96 | 
         
            +
                    output_batches = []
         
     | 
| 97 | 
         
            +
                    for dataset_idx in range(len(self.concat_dataset.datasets)):
         
     | 
| 98 | 
         
            +
                        cur_batch = []
         
     | 
| 99 | 
         
            +
                        for idx in iters[dataset_idx]:
         
     | 
| 100 | 
         
            +
                            cur_batch.append(idx + init_indices[dataset_idx])
         
     | 
| 101 | 
         
            +
                            if len(cur_batch) == self.batch_size:
         
     | 
| 102 | 
         
            +
                                output_batches.append(cur_batch)
         
     | 
| 103 | 
         
            +
                                cur_batch = []
         
     | 
| 104 | 
         
            +
                            if self.type == "valid" and len(cur_batch) > 0:
         
     | 
| 105 | 
         
            +
                                output_batches.append(cur_batch)
         
     | 
| 106 | 
         
            +
                                cur_batch = []
         
     | 
| 107 | 
         
            +
                    # force drop last in training
         
     | 
| 108 | 
         
            +
                    random.shuffle(output_batches)
         
     | 
| 109 | 
         
            +
                    output_indices = [item for sublist in output_batches for item in sublist]
         
     | 
| 110 | 
         
            +
                    return iter(output_indices)
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
            def build_samplers(concat_dataset: Dataset, cfg, logger, type):
         
     | 
| 114 | 
         
            +
                sampler = ScheduledSampler(
         
     | 
| 115 | 
         
            +
                    concat_dataset,
         
     | 
| 116 | 
         
            +
                    cfg.train.batch_size,
         
     | 
| 117 | 
         
            +
                    cfg.train.sampler.holistic_shuffle,
         
     | 
| 118 | 
         
            +
                    logger,
         
     | 
| 119 | 
         
            +
                    type,
         
     | 
| 120 | 
         
            +
                )
         
     | 
| 121 | 
         
            +
                batch_sampler = BatchSampler(
         
     | 
| 122 | 
         
            +
                    sampler,
         
     | 
| 123 | 
         
            +
                    cfg.train.batch_size,
         
     | 
| 124 | 
         
            +
                    cfg.train.sampler.drop_last if not type == "valid" else False,
         
     | 
| 125 | 
         
            +
                )
         
     | 
| 126 | 
         
            +
                return sampler, batch_sampler
         
     | 
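For orientation, here is a minimal usage sketch of the sampler above; `ToyDataset` and its sizes are hypothetical stand-ins for the real codec datasets, which expose `get_dataset_name()` in the same way:

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import BatchSampler

class ToyDataset(Dataset):
    def __init__(self, name, n):
        self.name, self.n = name, n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return idx
    def get_dataset_name(self):  # queried by ScheduledSampler for its warnings
        return self.name

concat = ConcatDataset([ToyDataset("a", 3), ToyDataset("b", 3)])
sampler = ScheduledSampler(concat, batch_size=3, holistic_shuffle=False)
print(list(BatchSampler(sampler, batch_size=3, drop_last=True)))
# e.g. [[3, 4, 5], [0, 1, 2]] -- per-dataset order kept, batch order shuffled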
    	
models/codec/codec_trainer.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
from pathlib import Path
import re

import accelerate
import json5
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.codec.codec_sampler import build_samplers


class CodecTrainer:
    def __init__(self):
        super().__init__()

    def _init_accelerator(self):
        """Initialize the accelerator components."""
        self.exp_dir = os.path.join(
            os.path.abspath(self.cfg.log_dir), self.args.exp_name
        )
        project_config = ProjectConfiguration(
            project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
        )
        self.accelerator = accelerate.Accelerator(
            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
            log_with=self.cfg.train.tracker,
            project_config=project_config,
        )
        if self.accelerator.is_main_process:
            os.makedirs(project_config.project_dir, exist_ok=True)
            os.makedirs(project_config.logging_dir, exist_ok=True)
        with self.accelerator.main_process_first():
            self.accelerator.init_trackers(self.args.exp_name)

    def _build_dataset(self):
        pass

    def _build_criterion(self):
        pass

    def _build_model(self):
        pass

    def _build_dataloader(self):
        """Build the training dataloader."""
        # Build the dataset and collator classes returned by the subclass
        Dataset, Collator = self._build_dataset()

        # Build train set
        train_dataset = Dataset(self.cfg, self.cfg.dataset, is_valid=False)
        train_collate = Collator(self.cfg)
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=self.accelerator.num_processes,
            rank=self.accelerator.local_process_index,
            shuffle=True,
            seed=self.cfg.train.random_seed,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.cfg.train.batch_size,
            collate_fn=train_collate,
            sampler=sampler,
            num_workers=self.cfg.train.dataloader.num_worker,
            pin_memory=self.cfg.train.dataloader.pin_memory,
        )
        return train_loader, None

    def _build_optimizer(self):
        pass

    def _build_scheduler(self):
        pass

    def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
        """Load model from checkpoint. If a directory is given, the latest
        checkpoint in checkpoint_dir is loaded; if a path is given, the
        checkpoint specified by checkpoint_path is loaded.
        **Only use this method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
        if resume_type == "resume":
            self.accelerator.load_state(checkpoint_path)
        elif resume_type == "finetune":
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            self.logger.info("Load model weights for finetune SUCCESS!")
        else:
            raise ValueError("Unsupported resume type: {}".format(resume_type))
        self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
        self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
        return checkpoint_path

    def train_loop(self):
        pass

    def _train_epoch(self):
        pass

    def _valid_epoch(self):
        pass

    def _train_step(self):
        pass

    def _valid_step(self):
        pass

    def _inference(self):
        pass

    def _set_random_seed(self, seed):
        """Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    def _check_nan(self, loss):
        if torch.any(torch.isnan(loss)):
            self.logger.fatal("Fatal Error: NaN!")
            self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)

    def _check_basic_configs(self):
        if self.cfg.train.gradient_accumulation_step <= 0:
            self.logger.fatal("Invalid gradient_accumulation_step value!")
            self.logger.error(
                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
            )
            self.accelerator.end_training()
            raise ValueError(
                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
            )

    def _count_parameters(self):
        pass

    def _dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        json5.dump(
            self.cfg,
            open(path, "w"),
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
            quote_keys=True,
        )

    def _is_valid_pattern(self, directory_name):
        directory_name = str(directory_name)
        pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
        return re.match(pattern, directory_name) is not None
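As a quick illustration (with hypothetical values), `_is_valid_pattern()` and `_load_model()` both assume checkpoint directories named after the same epoch/step/loss scheme:

import re

name = "epoch-0012_step-0034567_loss-0.123456"
assert re.match(r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}", name)
# _load_model() recovers the counters from the same fields:
epoch = int(name.split("_")[-3].split("-")[-1])  # 12
step = int(name.split("_")[-2].split("-")[-1])   # 34567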
    	
models/codec/facodec/__init__.py
ADDED
File without changes
    	
models/codec/facodec/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

from .filter import *
from .resample import *
from .act import *
    	
models/codec/facodec/alias_free_torch/act.py
ADDED
@@ -0,0 +1,29 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
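A short sketch of how `Activation1d` is meant to be used: the pointwise nonlinearity (here `nn.SiLU`, chosen only for illustration) is applied at twice the sample rate, so the harmonics it creates are filtered out before downsampling:

import torch
import torch.nn as nn

act = Activation1d(activation=nn.SiLU())
x = torch.randn(1, 8, 100)   # [B, C, T]
y = act(x)                   # upsample 2x -> SiLU -> downsample 2x
print(y.shape)               # torch.Size([1, 8, 100]): length is preserved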
    	
models/codec/facodec/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,96 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    # reshape to [1, 1, kernel_size]; the cutoff == 0 branch yields an all-zero filter
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
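A minimal sketch of the filter on its own, with illustrative parameters (cutoff is in units of the sampling rate, so 0.25 is half of Nyquist):

import torch

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
x = torch.randn(2, 4, 64)    # [B, C, T]
y = lpf(x)                   # depthwise conv with a Kaiser-windowed sinc kernel
print(y.shape)               # torch.Size([2, 4, 64]): stride 1 keeps the length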
    	
models/codec/facodec/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,57 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
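As a sketch, the two modules act as an approximately inverse pair, up to the filters' transition-band error:

import torch

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
x = torch.randn(1, 1, 128)
x_up = up(x)                 # [1, 1, 256]
x_rec = down(x_up)           # [1, 1, 128], close to x
print(x_up.shape, x_rec.shape)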
    	
models/codec/facodec/facodec_dataset.py
ADDED
@@ -0,0 +1,98 @@
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright (c) 2023 Amphion.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This source code is licensed under the MIT license found in the
         
     | 
| 4 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
            import random
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            import numpy as np
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            import torchaudio
         
     | 
| 12 | 
         
            +
            import librosa
         
     | 
| 13 | 
         
            +
            from torch.nn import functional as F
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            from torch.nn.utils.rnn import pad_sequence
         
     | 
| 16 | 
         
            +
            from utils.data_utils import *
         
     | 
| 17 | 
         
            +
            from models.codec.codec_dataset import CodecDataset
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            class FAcodecDataset(torch.utils.data.Dataset):
         
     | 
| 21 | 
         
            +
                def __init__(self, cfg, dataset, is_valid=False):
         
     | 
| 22 | 
         
            +
                    """
         
     | 
| 23 | 
         
            +
                    Args:
         
     | 
| 24 | 
         
            +
                        cfg: config
         
     | 
| 25 | 
         
            +
                        dataset: dataset name
         
     | 
| 26 | 
         
            +
                        is_valid: whether to use train or valid dataset
         
     | 
| 27 | 
         
            +
                    """
         
     | 
| 28 | 
         
            +
                    self.data_root_dir = cfg.dataset
         
     | 
| 29 | 
         
            +
                    self.data_list = []
         
     | 
| 30 | 
         
            +
                    # walk through the dataset directory recursively, save all files ends with .wav/.mp3/.opus/.flac/.m4a
         
     | 
| 31 | 
         
            +
                    for root, _, files in os.walk(self.data_root_dir):
         
     | 
| 32 | 
         
            +
                        for file in files:
         
     | 
| 33 | 
         
            +
                if file.endswith((".wav", ".mp3", ".opus", ".flac", ".m4a")):
                    self.data_list.append(os.path.join(root, file))
        self.sr = cfg.preprocess_params.sr
        self.duration_range = cfg.preprocess_params.duration_range
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=cfg.preprocess_params.spect_params.n_mels,
            n_fft=cfg.preprocess_params.spect_params.n_fft,
            win_length=cfg.preprocess_params.spect_params.win_length,
            hop_length=cfg.preprocess_params.spect_params.hop_length,
        )
        # Fixed statistics for log-mel normalization: mel is mapped to roughly
        # zero mean / unit variance via (log(mel) + 4) / 4.
        self.mean, self.std = -4, 4

    def preprocess(self, wave):
        wave_tensor = (
            torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave
        )
        mel_tensor = self.to_mel(wave_tensor)
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std
        return mel_tensor

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        wave, _ = librosa.load(self.data_list[index], sr=self.sr)
        # Debugging stub left in the diff, kept disabled here: it overwrote the
        # loaded audio with random noise of a random duration, discarding the file.
        # wave = np.random.randn(self.sr * random.randint(*self.duration_range))
        wave = wave / np.max(np.abs(wave))
        mel = self.preprocess(wave).squeeze(0)
        wave = torch.from_numpy(wave).float()
        return wave, mel


class FAcodecCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        # batch[i] = (wave, mel)
        batch_size = len(batch)

        # sort by mel length, longest first
        lengths = [b[1].shape[1] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        nmels = batch[0][1].size(0)
        max_mel_length = max([b[1].shape[1] for b in batch])
        max_wave_length = max([b[0].size(0) for b in batch])

        # pad value of -10 corresponds to silence on the normalized log-mel scale
        mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
        waves = torch.zeros((batch_size, max_wave_length)).float()

        mel_lengths = torch.zeros(batch_size).long()
        wave_lengths = torch.zeros(batch_size).long()

        for bid, (wave, mel) in enumerate(batch):
            mel_size = mel.size(1)
            mels[bid, :, :mel_size] = mel
            waves[bid, : wave.size(0)] = wave
            mel_lengths[bid] = mel_size
            wave_lengths[bid] = wave.size(0)

        return waves, mels, wave_lengths, mel_lengths
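
A minimal usage sketch (an editor's illustration, not part of the diff) showing how FAcodecDataset and FAcodecCollator plug into a plain PyTorch DataLoader. The cfg object and the FAcodecDataset constructor arguments are assumptions here; in the repo itself the trainer builds its loaders through build_samplers from models/codec/codec_sampler.py rather than directly like this.

# Illustrative only: cfg and the FAcodecDataset constructor signature are assumed.
from torch.utils.data import DataLoader

dataset = FAcodecDataset(cfg)  # hypothetical call; see the class definition above
collator = FAcodecCollator(cfg)
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collator,  # zero-pads waves/mels and returns their true lengths
    num_workers=4,
)
waves, mels, wave_lengths, mel_lengths = next(iter(loader))
# waves: (B, T_max); mels: (B, n_mels, L_max) with padded frames set to -10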
    	
models/codec/facodec/facodec_inference.py
ADDED
@@ -0,0 +1,137 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import shutil
import warnings
import argparse
import torch
import os
import yaml

warnings.simplefilter("ignore")

from .modules.commons import *
import time

import torchaudio
import librosa
from collections import OrderedDict


class FAcodecInference(object):
    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._build_model()
        self._load_checkpoint()

    def _build_model(self):
        model = build_model(self.cfg.model_params)
        _ = [model[key].to(self.device) for key in model]
        return model

    def _load_checkpoint(self):
        sd = torch.load(self.args.checkpoint_path, map_location="cpu")
        sd = sd["net"] if "net" in sd else sd
        new_params = dict()
        for key, state_dict in sd.items():
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if k.startswith("module."):
                    k = k[7:]
                new_state_dict[k] = v
            new_params[key] = new_state_dict
        for key in new_params:
            if key in self.model:
                self.model[key].load_state_dict(new_params[key])
        _ = [self.model[key].eval() for key in self.model]

    @torch.no_grad()
    def inference(self, source, output_dir):
        source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
        source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)

        z = self.model.encoder(source_audio[None, ...].to(self.device).float())
        (
            z,
            quantized,
            commitment_loss,
            codebook_loss,
            timbre,
            codes,
        ) = self.model.quantizer(
            z,
            source_audio[None, ...].to(self.device).float(),
            n_c=self.cfg.model_params.n_c_codebooks,
            return_codes=True,
        )

        full_pred_wave = self.model.decoder(z)

        os.makedirs(output_dir, exist_ok=True)
        source_name = source.split("/")[-1].split(".")[0]
        torchaudio.save(
            f"{output_dir}/reconstructed_{source_name}.wav",
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )

        print(
            "Reconstructed audio saved as: ",
            f"{output_dir}/reconstructed_{source_name}.wav",
        )

        return quantized, codes

    @torch.no_grad()
    def voice_conversion(self, source, reference, output_dir):
        source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
        source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)

        reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0]
        reference_audio = (
            torch.tensor(reference_audio).unsqueeze(0).float().to(self.device)
        )

        z = self.model.encoder(source_audio[None, ...].to(self.device).float())
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z,
            source_audio[None, ...].to(self.device).float(),
            n_c=self.cfg.model_params.n_c_codebooks,
        )

        z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float())
        (
            z_ref,
            quantized_ref,
            commitment_loss_ref,
            codebook_loss_ref,
            timbre_ref,
        ) = self.model.quantizer(
            z_ref,
            reference_audio[None, ...].to(self.device).float(),
            n_c=self.cfg.model_params.n_c_codebooks,
        )

        z_conv = self.model.quantizer.voice_conversion(
            quantized[0] + quantized[1],
            reference_audio[None, ...].to(self.device).float(),
        )
        full_pred_wave = self.model.decoder(z_conv)

        os.makedirs(output_dir, exist_ok=True)
        source_name = source.split("/")[-1].split(".")[0]
        reference_name = reference.split("/")[-1].split(".")[0]
        torchaudio.save(
            f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )

        print(
            "Voice conversion results saved as: ",
            f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
        )
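
A minimal usage sketch (an editor's illustration, not part of the diff) for FAcodecInference. The file paths and the YAML-to-attribute config shim below are assumptions; the class itself only needs args.checkpoint_path plus a cfg exposing model_params (including n_c_codebooks) and preprocess_params.sr.

# Illustrative only: paths and config handling here are assumptions.
import argparse
import yaml
from types import SimpleNamespace

from models.codec.facodec.facodec_inference import FAcodecInference

def to_ns(d):
    # Recursively wrap dicts so attribute access like cfg.preprocess_params.sr works.
    return SimpleNamespace(
        **{k: to_ns(v) if isinstance(v, dict) else v for k, v in d.items()}
    )

cfg = to_ns(yaml.safe_load(open("config.yml")))  # hypothetical config file
args = argparse.Namespace(checkpoint_path="ckpt/facodec.pth")  # hypothetical path

engine = FAcodecInference(args=args, cfg=cfg)
engine.inference("samples/source.wav", "outputs")  # writes reconstructed_source.wav
engine.voice_conversion("samples/source.wav", "samples/ref.wav", "outputs")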
    	
models/codec/facodec/facodec_trainer.py
ADDED
@@ -0,0 +1,776 @@
|
| 1 | 
         
            +
            # Copyright (c) 2023 Amphion.
         
     | 
| 2 | 
         
            +
            #
         
     | 
| 3 | 
         
            +
            # This source code is licensed under the MIT license found in the
         
     | 
| 4 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            import os
         
     | 
| 7 | 
         
            +
            import time
         
     | 
| 8 | 
         
            +
            import random
         
     | 
| 9 | 
         
            +
            from pathlib import Path
         
     | 
| 10 | 
         
            +
            import re
         
     | 
| 11 | 
         
            +
            import glob
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            import accelerate
         
     | 
| 14 | 
         
            +
            import json
         
     | 
| 15 | 
         
            +
            import numpy as np
         
     | 
| 16 | 
         
            +
            import torch
         
     | 
| 17 | 
         
            +
            from accelerate.utils import ProjectConfiguration
         
     | 
| 18 | 
         
            +
            from torch.utils.data import DataLoader
         
     | 
| 19 | 
         
            +
            from tqdm import tqdm
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            import torch
         
     | 
| 22 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 23 | 
         
            +
            import torchaudio
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
            from accelerate.logging import get_logger
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
            from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator
         
     | 
| 28 | 
         
            +
            from models.codec.codec_sampler import build_samplers
         
     | 
| 29 | 
         
            +
            from models.codec.codec_trainer import CodecTrainer
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            from modules.dac.nn.loss import (
         
     | 
| 32 | 
         
            +
                MultiScaleSTFTLoss,
         
     | 
| 33 | 
         
            +
                MelSpectrogramLoss,
         
     | 
| 34 | 
         
            +
                GANLoss,
         
     | 
| 35 | 
         
            +
                L1Loss,
         
     | 
| 36 | 
         
            +
                FocalLoss,
         
     | 
| 37 | 
         
            +
            )
         
     | 
| 38 | 
         
            +
            from audiotools import AudioSignal
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
            from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
            try:
         
     | 
| 43 | 
         
            +
                import nemo.collections.asr as nemo_asr
         
     | 
| 44 | 
         
            +
            except ImportError:
         
     | 
| 45 | 
         
            +
                print(
         
     | 
| 46 | 
         
            +
                    "Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING"
         
     | 
| 47 | 
         
            +
                )
         
     | 
| 48 | 
         
            +
                nemo_asr = None
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            from models.codec.facodec.modules.commons import (
         
     | 
| 51 | 
         
            +
                build_model,
         
     | 
| 52 | 
         
            +
                load_checkpoint,
         
     | 
| 53 | 
         
            +
                load_F0_models,
         
     | 
| 54 | 
         
            +
                log_norm,
         
     | 
| 55 | 
         
            +
            )
         
     | 
| 56 | 
         
            +
            from models.codec.facodec.optimizer import build_optimizer
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
            class FAcodecTrainer(CodecTrainer):
         
     | 
| 60 | 
         
            +
                def __init__(self, args, cfg):
         
     | 
| 61 | 
         
            +
                    super().__init__()
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
                    self.args = args
         
     | 
| 64 | 
         
            +
                    self.cfg = cfg
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    cfg.exp_name = args.exp_name
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
                    # Init accelerator
         
     | 
| 69 | 
         
            +
                    self._init_accelerator()
         
     | 
| 70 | 
         
            +
                    self.accelerator.wait_for_everyone()
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                    # Init logger
         
     | 
| 73 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 74 | 
         
            +
                        self.logger = get_logger(args.exp_name, log_level=args.log_level)
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                    self.logger.info("=" * 56)
         
     | 
| 77 | 
         
            +
                    self.logger.info("||\t\t" + "New training process started." + "\t\t||")
         
     | 
| 78 | 
         
            +
                    self.logger.info("=" * 56)
         
     | 
| 79 | 
         
            +
                    self.logger.info("\n")
         
     | 
| 80 | 
         
            +
                    self.logger.debug(f"Using {args.log_level.upper()} logging level.")
         
     | 
| 81 | 
         
            +
                    self.logger.info(f"Experiment name: {args.exp_name}")
         
     | 
| 82 | 
         
            +
                    self.logger.info(f"Experiment directory: {self.exp_dir}")
         
     | 
| 83 | 
         
            +
                    self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
         
     | 
| 84 | 
         
            +
                    if self.accelerator.is_main_process:
         
     | 
| 85 | 
         
            +
                        os.makedirs(self.checkpoint_dir, exist_ok=True)
         
     | 
| 86 | 
         
            +
                    self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                    # Init training status
         
     | 
| 89 | 
         
            +
                    self.batch_count: int = 0
         
     | 
| 90 | 
         
            +
                    self.step: int = 0
         
     | 
| 91 | 
         
            +
                    self.epoch: int = 0
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                    self.max_epoch = (
         
     | 
| 94 | 
         
            +
                        self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
         
     | 
| 95 | 
         
            +
                    )
         
     | 
| 96 | 
         
            +
                    self.logger.info(
         
     | 
| 97 | 
         
            +
                        "Max epoch: {}".format(
         
     | 
| 98 | 
         
            +
                            self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
         
     | 
| 99 | 
         
            +
                        )
         
     | 
| 100 | 
         
            +
                    )
         
     | 
| 101 | 
         
            +
             
     | 
| 102 | 
         
            +
                    # Check potential erorrs
         
     | 
| 103 | 
         
            +
                    if self.accelerator.is_main_process:
         
     | 
| 104 | 
         
            +
                        self._check_basic_configs()
         
     | 
| 105 | 
         
            +
                        self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
         
     | 
| 106 | 
         
            +
                        self.checkpoints_path = [
         
     | 
| 107 | 
         
            +
                            [] for _ in range(len(self.save_checkpoint_stride))
         
     | 
| 108 | 
         
            +
                        ]
         
     | 
| 109 | 
         
            +
                        self.run_eval = self.cfg.train.run_eval
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                    # Set random seed
         
     | 
| 112 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 113 | 
         
            +
                        start = time.monotonic_ns()
         
     | 
| 114 | 
         
            +
                        self._set_random_seed(self.cfg.train.random_seed)
         
     | 
| 115 | 
         
            +
                        end = time.monotonic_ns()
         
     | 
| 116 | 
         
            +
                        self.logger.debug(
         
     | 
| 117 | 
         
            +
                            f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
         
     | 
| 118 | 
         
            +
                        )
         
     | 
| 119 | 
         
            +
                        self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                    # Build dataloader
         
     | 
| 122 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 123 | 
         
            +
                        self.logger.info("Building dataset...")
         
     | 
| 124 | 
         
            +
                        start = time.monotonic_ns()
         
     | 
| 125 | 
         
            +
                        self.train_dataloader, self.valid_dataloader = self._build_dataloader()
         
     | 
| 126 | 
         
            +
                        end = time.monotonic_ns()
         
     | 
| 127 | 
         
            +
                        self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                    # Build model
         
     | 
| 130 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 131 | 
         
            +
                        self.logger.info("Building model...")
         
     | 
| 132 | 
         
            +
                        start = time.monotonic_ns()
         
     | 
| 133 | 
         
            +
                        self.model = self._build_model()
         
     | 
| 134 | 
         
            +
                        end = time.monotonic_ns()
         
     | 
| 135 | 
         
            +
                        for _, model in self.model.items():
         
     | 
| 136 | 
         
            +
                            self.logger.debug(model)
         
     | 
| 137 | 
         
            +
                        self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
         
     | 
| 138 | 
         
            +
                        self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
         
     | 
| 139 | 
         
            +
             
     | 
| 140 | 
         
            +
                    # Build optimizers and schedulers
         
     | 
| 141 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 142 | 
         
            +
                        self.logger.info("Building optimizer and scheduler...")
         
     | 
| 143 | 
         
            +
                        start = time.monotonic_ns()
         
     | 
| 144 | 
         
            +
                        self.optimizer = self._build_optimizer()
         
     | 
| 145 | 
         
            +
                        end = time.monotonic_ns()
         
     | 
| 146 | 
         
            +
                        self.logger.info(
         
     | 
| 147 | 
         
            +
                            f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
         
     | 
| 148 | 
         
            +
                        )
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                    # Build helper models
         
     | 
| 151 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 152 | 
         
            +
                        self.logger.info("Building helper models...")
         
     | 
| 153 | 
         
            +
                        start = time.monotonic_ns()
         
     | 
| 154 | 
         
            +
                        self._built_helper_model()
         
     | 
| 155 | 
         
            +
                        end = time.monotonic_ns()
         
     | 
| 156 | 
         
            +
                        self.logger.info(
         
     | 
| 157 | 
         
            +
                            f"Building helper models done in {(end - start) / 1e6:.2f}ms"
         
     | 
| 158 | 
         
            +
                        )
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
                    # Accelerator preparing
         
     | 
| 161 | 
         
            +
                    self.logger.info("Initializing accelerate...")
         
     | 
| 162 | 
         
            +
                    start = time.monotonic_ns()
         
     | 
| 163 | 
         
            +
                    for k in self.model:
         
     | 
| 164 | 
         
            +
                        self.model[k] = self.accelerator.prepare(self.model[k])
         
     | 
| 165 | 
         
            +
                    for k, v in self.optimizer.optimizers.items():
         
     | 
| 166 | 
         
            +
                        self.optimizer.optimizers[k] = self.accelerator.prepare(
         
     | 
| 167 | 
         
            +
                            self.optimizer.optimizers[k]
         
     | 
| 168 | 
         
            +
                        )
         
     | 
| 169 | 
         
            +
                        self.optimizer.schedulers[k] = self.accelerator.prepare(
         
     | 
| 170 | 
         
            +
                            self.optimizer.schedulers[k]
         
     | 
| 171 | 
         
            +
                        )
         
     | 
| 172 | 
         
            +
                    end = time.monotonic_ns()
         
     | 
| 173 | 
         
            +
                    self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
            +
                    # Build criterions
         
     | 
| 176 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 177 | 
         
            +
                        self.logger.info("Building criterion...")
         
     | 
| 178 | 
         
            +
                        start = time.monotonic_ns()
         
     | 
| 179 | 
         
            +
                        self.criterions = self._build_criterion()
         
     | 
| 180 | 
         
            +
                        end = time.monotonic_ns()
         
     | 
| 181 | 
         
            +
                        self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                    # Resume checkpoints
         
     | 
| 184 | 
         
            +
                    with self.accelerator.main_process_first():
         
     | 
| 185 | 
         
            +
                        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
         
     | 
| 186 | 
         
            +
                        if args.resume_type:
         
     | 
| 187 | 
         
            +
                            self.logger.info("Resuming from checkpoint...")
         
     | 
| 188 | 
         
            +
                            start = time.monotonic_ns()
         
     | 
| 189 | 
         
            +
                            ckpt_path = Path(args.checkpoint)
         
     | 
| 190 | 
         
            +
                            if self._is_valid_pattern(ckpt_path.parts[-1]):
         
     | 
| 191 | 
         
            +
                                ckpt_path = self._load_model(args.checkpoint, args.resume_type)
         
     | 
| 192 | 
         
            +
                            else:
         
     | 
| 193 | 
         
            +
                                ckpt_path = self._load_model(
         
     | 
| 194 | 
         
            +
                                    args.checkpoint, resume_type=args.resume_type
         
     | 
| 195 | 
         
            +
                                )
         
     | 
| 196 | 
         
            +
                            end = time.monotonic_ns()
         
     | 
| 197 | 
         
            +
                            self.logger.info(
         
     | 
| 198 | 
         
            +
                                f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
         
     | 
| 199 | 
         
            +
                            )
         
     | 
| 200 | 
         
            +
                            self.checkpoints_path = json.load(
         
     | 
| 201 | 
         
            +
                                open(os.path.join(ckpt_path, "ckpts.json"), "r")
         
     | 
| 202 | 
         
            +
                            )
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                        if self.accelerator.is_main_process:
         
     | 
| 205 | 
         
            +
                            os.makedirs(self.checkpoint_dir, exist_ok=True)
         
     | 
| 206 | 
         
            +
                        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                    # Save config
         
     | 
| 209 | 
         
            +
                    self.config_save_path = os.path.join(self.exp_dir, "args.json")
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                def _build_dataset(self):
         
     | 
| 212 | 
         
            +
                    return FAcodecDataset, FAcodecCollator
         
     | 
| 213 | 
         
            +
             
     | 
| 214 | 
         
            +
                def _build_criterion(self):
         
     | 
| 215 | 
         
            +
                    criterions = dict()
         
     | 
| 216 | 
         
            +
                    stft_criterion = MultiScaleSTFTLoss()
         
     | 
| 217 | 
         
            +
                    mel_criterion = MelSpectrogramLoss(
         
     | 
| 218 | 
         
            +
                        n_mels=[5, 10, 20, 40, 80, 160, 320],
         
     | 
| 219 | 
         
            +
                        window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
         
     | 
| 220 | 
         
            +
                        mel_fmin=[0, 0, 0, 0, 0, 0, 0],
         
     | 
| 221 | 
         
            +
                        mel_fmax=[None, None, None, None, None, None, None],
         
     | 
| 222 | 
         
            +
                        pow=1.0,
         
     | 
| 223 | 
         
            +
            mag_weight=0.0,
            clamp_eps=1e-5,
        )
        content_criterion = FocalLoss(gamma=2)
        l1_criterion = L1Loss()
        criterions["stft"] = stft_criterion
        criterions["mel"] = mel_criterion
        criterions["l1"] = l1_criterion
        criterions["content"] = content_criterion

        return criterions

    def _build_model(self):
        model = build_model(self.cfg.model_params)
        _ = [model[key].to(self.accelerator.device) for key in model]
        return model

    def _built_helper_model(self):
        device = self.accelerator.device
        self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device)

        # load model and processor
        self.w2v_processor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
        )
        self.w2v_model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
        ).to(device)
        self.w2v_model.eval()

        if nemo_asr is None:
            self.speaker_model = None
        else:
            self.speaker_model = (
                nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
                    "nvidia/speakerverification_en_titanet_large"
                )
            )
            self.speaker_model = self.speaker_model.to(device)
            self.speaker_model.eval()

    def _build_optimizer(self):
        scheduler_params = {
            "warmup_steps": self.cfg.loss_params.warmup_steps,
            "base_lr": self.cfg.loss_params.base_lr,
        }
        optimizer = build_optimizer(
            {key: self.model[key] for key in self.model},
            scheduler_params_dict={key: scheduler_params.copy() for key in self.model},
            lr=float(scheduler_params["base_lr"]),
        )

        return optimizer

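    # Note on the optimizer above: judging from the calls later in this file
    # (self.optimizer.schedulers["encoder"], self.optimizer.step("discriminator"),
    # self.optimizer.scheduler(key="discriminator")), build_optimizer appears to
    # wrap one optimizer and one warmup scheduler per sub-module, keyed by the
    # same names as self.model, so each component can be stepped in isolation.
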
    def train_loop(self):
        """Training process"""
        self.accelerator.wait_for_everyone()

        # Dump config
        if self.accelerator.is_main_process:
            self._dump_cfg(self.config_save_path)
        _ = [self.model[key].train() for key in self.model]
        self.optimizer.zero_grad()

        # Sync and start training
        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Train and Validate
            train_total_loss, train_losses = self._train_epoch()
            for key, loss in train_losses.items():
                self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
                self.accelerator.log(
                    {"Epoch/Train {} Loss".format(key): loss},
                    step=self.epoch,
                )
            self.accelerator.log(
                {
                    "Epoch/Train Total Loss": train_total_loss,
                },
                step=self.epoch,
            )

            # Update scheduler
            self.accelerator.wait_for_everyone()

            # Check save checkpoint interval
            run_eval = False
            if self.accelerator.is_main_process:
                save_checkpoint = False
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        run_eval |= self.run_eval[i]

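            # Illustration of the stride logic above: with
            # save_checkpoint_stride = [1, 5] and run_eval = [False, True]
            # (hypothetical values), epoch 10 matches both strides
            # (10 % 1 == 0 and 10 % 5 == 0), so a checkpoint is saved and
            # evaluation is requested, while epoch 3 matches only the first
            # stride and saves without evaluating.
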
            # Save checkpoints
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                print("Saving..")
                state = {
                    "net": {key: self.model[key].state_dict() for key in self.model},
                    "optimizer": self.optimizer.state_dict(),
                    "scheduler": self.optimizer.scheduler_state_dict(),
                    "iters": self.step,
                    "epoch": self.epoch,
                }
                save_path = os.path.join(
                    self.checkpoint_dir,
                    "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.step),
                )
                torch.save(state, save_path)
                json.dump(
                    self.checkpoints_path,
                    open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w"),
                    ensure_ascii=False,
                    indent=4,
                )

            self.accelerator.wait_for_everyone()

            self.epoch += 1

        # Finish training
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = os.path.join(
                self.checkpoint_dir,
                "epoch-{:04d}_step-{:07d}".format(
                    self.epoch,
                    self.step,
                ),
            )
            print("Saving..")
            state = {
                "net": {key: self.model[key].state_dict() for key in self.model},
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.optimizer.scheduler_state_dict(),
                "iters": self.step,
                "epoch": self.epoch,
            }
            save_path = os.path.join(
                self.checkpoint_dir,
                "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.step),
            )
            torch.save(state, save_path)

    def _train_epoch(self):
        """Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        _ = [self.model[key].train() for key in self.model]

        epoch_losses: dict = {}
        epoch_total_loss: float = 0.0

        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Get losses
            total_loss, losses = self._train_step(batch)
            self.batch_count += 1

            # Log info
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                self.accelerator.log(
                    {
                        "Step/Learning Rate": (
                            self.optimizer.schedulers["encoder"].get_last_lr()[0]
                            if self.step != 0
                            else 0
                        )
                    },
                    step=self.step,
                )
                for key, _ in losses.items():
                    self.accelerator.log(
                        {
                            "Step/Train {} Loss".format(key): losses[key],
                        },
                        step=self.step,
                    )

                if not epoch_losses:
                    epoch_losses = losses
                else:
                    for key, value in losses.items():
                        epoch_losses[key] += value
                epoch_total_loss += total_loss
                self.step += 1

        # Get and log total losses
        self.accelerator.wait_for_everyone()
        epoch_total_loss = (
            epoch_total_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )
        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )
        return epoch_total_loss, epoch_losses

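    # Averaging note for _train_epoch above: losses are accumulated only once
    # every gradient_accumulation_step batches, i.e. roughly
    # len(train_dataloader) / gradient_accumulation_step times per epoch, so
    # dividing by len(train_dataloader) and multiplying by
    # gradient_accumulation_step recovers the mean over the logged steps.
    # For example (illustrative numbers): 100 batches with accumulation step 4
    # log 25 values; if they sum to 50, then 50 / 100 * 4 = 2.0, their mean.
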
    def _train_step(self, data):
        """Training forward step. Should return average loss of a sample over
        one batch. Invoking ``_forward_step`` is recommended except in special
        cases. See ``_train_epoch`` for usage.
        """
        # Init losses
        train_losses = {}
        total_loss = 0

        # Use input feature to get predictions
        data = [b.to(self.accelerator.device, non_blocking=True) for b in data]
        waves, mels, wave_lengths, mel_input_length = data

        # extract semantic latent with w2v model
        waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
        w2v_input = self.w2v_processor(
            waves_16k, sampling_rate=16000, return_tensors="pt"
        ).input_values.to(self.accelerator.device)
        with torch.no_grad():
            w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
            predicted_ids = torch.argmax(w2v_outputs, dim=-1)
            phone_ids = (
                F.interpolate(
                    predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
                )
                .long()
                .squeeze(0)
            )

        # get clips
        mel_seg_len = min(
            [int(mel_input_length.min().item()), self.cfg.train.max_frame_len]
        )

        gt_mel_seg = []
        wav_seg = []
        w2v_seg = []

        for bib in range(len(mel_input_length)):
            mel_length = int(mel_input_length[bib].item())

            random_start = (
                np.random.randint(0, mel_length - mel_seg_len)
                if mel_length != mel_seg_len
                else 0
            )
            gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len])

            # w2v_seg.append(w2v_latent[bib, :, random_start:random_start + mel_seg_len])
            w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len])

            y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300]

            wav_seg.append(y.to(self.accelerator.device))

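        # Index arithmetic in the crop above: mel frames are spaced 300 samples
        # apart at 24 kHz (24000 / 300 = 80 frames per second, matching the
        # frame_rate of 80 assumed below), so mel frames
        # [random_start, random_start + mel_seg_len) line up with waveform
        # samples [random_start * 300, (random_start + mel_seg_len) * 300).
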
        gt_mel_seg = torch.stack(gt_mel_seg).detach()

        wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
        w2v_seg = torch.stack(w2v_seg).float().detach()

        with torch.no_grad():
            real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
            F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))

        # normalize f0
        # Remove unvoiced frames (replace with -10)
        gt_glob_f0s = []
        f0_targets = []
        for bib in range(len(F0_real)):
            voiced_indices = F0_real[bib] > 5.0
            f0_voiced = F0_real[bib][voiced_indices]

            if len(f0_voiced) != 0:
                # Convert to log scale
                log_f0 = f0_voiced.log2()

                # Calculate mean and standard deviation
                mean_f0 = log_f0.mean()
                std_f0 = log_f0.std()

                # Normalize the F0 sequence
                normalized_f0 = (log_f0 - mean_f0) / std_f0

                # Create the normalized F0 sequence with unvoiced frames
                normalized_sequence = torch.zeros_like(F0_real[bib])
                normalized_sequence[voiced_indices] = normalized_f0
                normalized_sequence[~voiced_indices] = (
                    -10
                )  # Assign -10 to unvoiced frames

                gt_glob_f0s.append(mean_f0)
            else:
                normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0
                gt_glob_f0s.append(torch.tensor(0.0).to(self.accelerator.device))

            # f0_targets.append(normalized_sequence[single_side_context // 200:-single_side_context // 200])
            f0_targets.append(normalized_sequence)
        f0_targets = torch.stack(f0_targets).to(self.accelerator.device)
        # fill nan with -10
        f0_targets[torch.isnan(f0_targets)] = -10.0
        # fill inf with -10
        f0_targets[torch.isinf(f0_targets)] = -10.0
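        # Worked example of the per-utterance F0 normalization above
        # (illustrative values): voiced frames F0 = [220 Hz, 440 Hz] give
        # log2 F0 = [7.78, 8.78] with mean 8.28 and std 0.71, so the voiced
        # targets become [-0.71, +0.71], while unvoiced, NaN, and inf frames
        # are all pinned to -10, far outside the voiced range.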
        # if frame_rate not equal to 80, interpolate f0 from frame rate of 80 to target frame rate
        if self.cfg.preprocess_params.frame_rate != 80:
            f0_targets = F.interpolate(
                f0_targets.unsqueeze(1),
                mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
                mode="nearest",
            ).squeeze(1)
            w2v_seg = F.interpolate(
                w2v_seg,
                mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
                mode="nearest",
            )

        wav_seg_input = wav_seg
        wav_seg_target = wav_seg

        z = self.model.encoder(wav_seg_input)
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
        )
        preds, rev_preds = self.model.fa_predictors(quantized, timbre)

        pred_wave = self.model.decoder(z)

        len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
        if len_diff > 0:
            wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]

        # discriminator loss
        d_fake = self.model.discriminator(pred_wave.detach())
        d_real = self.model.discriminator(wav_seg_target)
        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)

        self.optimizer.zero_grad()
        self.accelerator.backward(loss_d)
        grad_norm_d = torch.nn.utils.clip_grad_norm_(
            self.model.discriminator.parameters(), 10.0
        )
        self.optimizer.step("discriminator")
        self.optimizer.scheduler(key="discriminator")

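        # The adversarial objective here is the least-squares GAN: the update
        # above pushes the discriminator toward D(real) = 1 and D(fake) = 0 by
        # minimizing E[D(fake)^2] + E[(1 - D(real))^2], and the generator term
        # below mirrors it with E[(1 - D(fake))^2]. x[-1] is each
        # sub-discriminator's final-layer output; the earlier layers feed the
        # feature-matching L1 loss.
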
        # generator loss
        signal = AudioSignal(wav_seg_target, sample_rate=24000)
        recons = AudioSignal(pred_wave, sample_rate=24000)
        stft_loss = self.criterions["stft"](recons, signal)
        mel_loss = self.criterions["mel"](recons, signal)
        waveform_loss = self.criterions["l1"](recons, signal)

        d_fake = self.model.discriminator(pred_wave)
        d_real = self.model.discriminator(wav_seg_target)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        loss_feature = 0

        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())

        pred_f0, pred_uv = preds["f0"], preds["uv"]
        rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]

        common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
        f0_targets = f0_targets[..., :common_min_size]
        real_norm = real_norm[..., :common_min_size]

        f0_loss = F.smooth_l1_loss(
            f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
        )
        uv_loss = F.smooth_l1_loss(
            real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
        )
        rev_f0_loss = (
            F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
            if rev_pred_f0 is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )
        rev_uv_loss = (
            F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
            if rev_pred_uv is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )

        tot_f0_loss = f0_loss + rev_f0_loss
        tot_uv_loss = uv_loss + rev_uv_loss

        pred_content = preds["content"]
        rev_pred_content = rev_preds["rev_content"]

        target_content_latents = w2v_seg[..., :common_min_size]

        content_loss = self.criterions["content"](
            pred_content.transpose(1, 2)[..., :common_min_size],
            target_content_latents.long(),
        )
        rev_content_loss = (
            self.criterions["content"](
                rev_pred_content.transpose(1, 2)[..., :common_min_size],
                target_content_latents.long(),
            )
            if rev_pred_content is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )

        tot_content_loss = content_loss + rev_content_loss

        if self.speaker_model is not None:
            spk_logits = torch.cat(
                [
                    self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
                    for w16, wl in zip(waves_16k, wave_lengths)
                ],
                dim=0,
            )
            spk_labels = spk_logits.argmax(dim=-1)
        else:
            spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(
                self.accelerator.device
            )

        spk_pred_logits = preds["timbre"]
        spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
        x_spk_pred_logits = rev_preds["x_timbre"]

        x_spk_loss = (
            F.cross_entropy(x_spk_pred_logits, spk_labels)
            if x_spk_pred_logits is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )

        tot_spk_loss = spk_loss + x_spk_loss

        loss_gen_all = (
            mel_loss * 15.0
            + loss_feature * 1.0
            + loss_g * 1.0
            + commitment_loss * 0.25
            + codebook_loss * 1.0
            + tot_f0_loss * 1.0
            + tot_uv_loss * 1.0
            + tot_content_loss * 5.0
            + tot_spk_loss * 5.0
        )

        self.optimizer.zero_grad()
        self.accelerator.backward(loss_gen_all)

        with torch.no_grad():
            total_loss = loss_gen_all.item()
            train_losses["stft"] = stft_loss.item()
            train_losses["mel"] = mel_loss.item()
            train_losses["l1"] = waveform_loss.item()
            train_losses["f0"] = f0_loss.item()
            train_losses["uv"] = uv_loss.item()
            train_losses["content"] = content_loss.item()
            train_losses["speaker"] = spk_loss.item()
            train_losses["rev_f0"] = rev_f0_loss.item()
            train_losses["rev_uv"] = rev_uv_loss.item()
            train_losses["rev_content"] = rev_content_loss.item()
            train_losses["rev_speaker"] = x_spk_loss.item()

            train_losses["feature"] = loss_feature.item()
            train_losses["generator"] = loss_g.item()
            train_losses["commitment"] = commitment_loss.item()
            train_losses["codebook"] = codebook_loss.item()

            # discriminators
            train_losses["discriminator"] = loss_d.item()

        return total_loss, train_losses

    def _inference(self, eval_wave):
        """Inference during training for test audios."""
        z = self.model.encoder(
            eval_wave[None, None, ...].to(self.accelerator.device).float()
        )
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks
        )
        full_pred_wave = self.model.decoder(z)
        return full_pred_wave[0]

    def _load_model(self, checkpoint_path=None, resume_type="resume"):
        """Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """
        if resume_type == "resume":
            if checkpoint_path is None:
                available_checkpoints = glob.glob(
                    os.path.join(self.checkpoint_dir, "FAcodec_epoch_*_step_*.pth")
                )
                # find the checkpoint that has the highest step number
                latest_checkpoint = max(
                    available_checkpoints,
                    key=lambda x: int(x.split("_")[-1].split(".")[0]),
                )
                earliest_checkpoint = min(
                    available_checkpoints,
                    key=lambda x: int(x.split("_")[-1].split(".")[0]),
                )
                # delete the earliest checkpoint
                if (
                    earliest_checkpoint != latest_checkpoint
                    and self.accelerator.is_main_process
                    and len(available_checkpoints) > 4
                ):
                    os.remove(earliest_checkpoint)
                    print(f"Removed {earliest_checkpoint}")
            else:
                latest_checkpoint = checkpoint_path

            self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
                self.model,
                self.optimizer,
                latest_checkpoint,
                load_only_params=False,
                ignore_modules=[],
                is_distributed=self.accelerator.num_processes > 1,
            )

        else:
            raise ValueError("Invalid resume type")
        return checkpoint_path

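    # Filename parsing in _load_model above: the sort key recovers the step
    # number from the checkpoint name, e.g. for a hypothetical
    # "FAcodec_epoch_00010_step_05000.pth", split("_")[-1] gives "05000.pth"
    # and split(".")[0] gives "05000" -> 5000, so max()/min() select the
    # newest/oldest checkpoint by training step.
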
    def _count_parameters(self):
        total_num = sum(
            sum(p.numel() for p in self.model[key].parameters()) for key in self.model
        )
        # trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        return total_num

models/codec/facodec/modules/JDC/__init__.py
ADDED
@@ -0,0 +1 @@

models/codec/facodec/modules/JDC/bst.t7
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
size 21029926

models/codec/facodec/modules/JDC/model.py
ADDED
@@ -0,0 +1,219 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py

"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import torch
from torch import nn


class JDCNet(nn.Module):
    """
    Joint Detection and Classification Network model for singing voice melody.
    """

    def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
        super().__init__()
        self.num_class = num_class

        # input = (b, 1, 31, 513), b = batch size
        self.conv_block = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
            ),  # out: (b, 64, 31, 513)
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
        )

        # res blocks
        self.res_block1 = ResBlock(
            in_channels=64, out_channels=128
        )  # (b, 128, 31, 128)
        self.res_block2 = ResBlock(
            in_channels=128, out_channels=192
        )  # (b, 192, 31, 32)
        self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)
         
            +
             
     | 
| 46 | 
         
            +
                    # pool block
         
     | 
| 47 | 
         
            +
                    self.pool_block = nn.Sequential(
         
     | 
| 48 | 
         
            +
                        nn.BatchNorm2d(num_features=256),
         
     | 
| 49 | 
         
            +
                        nn.LeakyReLU(leaky_relu_slope, inplace=True),
         
     | 
| 50 | 
         
            +
                        nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
         
     | 
| 51 | 
         
            +
                        nn.Dropout(p=0.2),
         
     | 
| 52 | 
         
            +
                    )
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                    # maxpool layers (for auxiliary network inputs)
         
     | 
| 55 | 
         
            +
                    # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
         
     | 
| 56 | 
         
            +
                    self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
         
     | 
| 57 | 
         
            +
                    # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
         
     | 
| 58 | 
         
            +
                    self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
         
     | 
| 59 | 
         
            +
                    # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
         
     | 
| 60 | 
         
            +
                    self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                    # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
         
     | 
| 63 | 
         
            +
                    self.detector_conv = nn.Sequential(
         
     | 
| 64 | 
         
            +
                        nn.Conv2d(640, 256, 1, bias=False),
         
     | 
| 65 | 
         
            +
                        nn.BatchNorm2d(256),
         
     | 
| 66 | 
         
            +
                        nn.LeakyReLU(leaky_relu_slope, inplace=True),
         
     | 
| 67 | 
         
            +
                        nn.Dropout(p=0.2),
         
     | 
| 68 | 
         
            +
                    )
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                    # input: (b, 31, 512) - resized from (b, 256, 31, 2)
         
     | 
| 71 | 
         
            +
                    self.bilstm_classifier = nn.LSTM(
         
     | 
| 72 | 
         
            +
                        input_size=512, hidden_size=256, batch_first=True, bidirectional=True
         
     | 
| 73 | 
         
            +
                    )  # (b, 31, 512)
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
                    # input: (b, 31, 512) - resized from (b, 256, 31, 2)
         
     | 
| 76 | 
         
            +
                    self.bilstm_detector = nn.LSTM(
         
     | 
| 77 | 
         
            +
                        input_size=512, hidden_size=256, batch_first=True, bidirectional=True
         
     | 
| 78 | 
         
            +
                    )  # (b, 31, 512)
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
                    # input: (b * 31, 512)
         
     | 
| 81 | 
         
            +
                    self.classifier = nn.Linear(
         
     | 
| 82 | 
         
            +
                        in_features=512, out_features=self.num_class
         
     | 
| 83 | 
         
            +
                    )  # (b * 31, num_class)
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
                    # input: (b * 31, 512)
         
     | 
| 86 | 
         
            +
                    self.detector = nn.Linear(
         
     | 
| 87 | 
         
            +
                        in_features=512, out_features=2
         
     | 
| 88 | 
         
            +
                    )  # (b * 31, 2) - binary classifier
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                    # initialize weights
         
     | 
| 91 | 
         
            +
                    self.apply(self.init_weights)
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                def get_feature_GAN(self, x):
         
     | 
| 94 | 
         
            +
                    seq_len = x.shape[-2]
         
     | 
| 95 | 
         
            +
                    x = x.float().transpose(-1, -2)
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                    convblock_out = self.conv_block(x)
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                    resblock1_out = self.res_block1(convblock_out)
         
     | 
| 100 | 
         
            +
                    resblock2_out = self.res_block2(resblock1_out)
         
     | 
| 101 | 
         
            +
                    resblock3_out = self.res_block3(resblock2_out)
         
     | 
| 102 | 
         
            +
                    poolblock_out = self.pool_block[0](resblock3_out)
         
     | 
| 103 | 
         
            +
                    poolblock_out = self.pool_block[1](poolblock_out)
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
                    return poolblock_out.transpose(-1, -2)
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                def get_feature(self, x):
         
     | 
| 108 | 
         
            +
                    seq_len = x.shape[-2]
         
     | 
| 109 | 
         
            +
                    x = x.float().transpose(-1, -2)
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                    convblock_out = self.conv_block(x)
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                    resblock1_out = self.res_block1(convblock_out)
         
     | 
| 114 | 
         
            +
                    resblock2_out = self.res_block2(resblock1_out)
         
     | 
| 115 | 
         
            +
                    resblock3_out = self.res_block3(resblock2_out)
         
     | 
| 116 | 
         
            +
                    poolblock_out = self.pool_block[0](resblock3_out)
         
     | 
| 117 | 
         
            +
                    poolblock_out = self.pool_block[1](poolblock_out)
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                    return self.pool_block[2](poolblock_out)
         
     | 
| 120 | 
         
            +
             
     | 
| 121 | 
         
            +
                def forward(self, x):
         
     | 
| 122 | 
         
            +
                    """
         
     | 
| 123 | 
         
            +
                    Returns:
         
     | 
| 124 | 
         
            +
                        classification_prediction, detection_prediction
         
     | 
| 125 | 
         
            +
                        sizes: (b, 31, 722), (b, 31, 2)
         
     | 
| 126 | 
         
            +
                    """
         
     | 
| 127 | 
         
            +
                    ###############################
         
     | 
| 128 | 
         
            +
                    # forward pass for classifier #
         
     | 
| 129 | 
         
            +
                    ###############################
         
     | 
| 130 | 
         
            +
                    seq_len = x.shape[-1]
         
     | 
| 131 | 
         
            +
                    x = x.float().transpose(-1, -2)
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                    convblock_out = self.conv_block(x)
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                    resblock1_out = self.res_block1(convblock_out)
         
     | 
| 136 | 
         
            +
                    resblock2_out = self.res_block2(resblock1_out)
         
     | 
| 137 | 
         
            +
                    resblock3_out = self.res_block3(resblock2_out)
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
                    poolblock_out = self.pool_block[0](resblock3_out)
         
     | 
| 140 | 
         
            +
                    poolblock_out = self.pool_block[1](poolblock_out)
         
     | 
| 141 | 
         
            +
                    GAN_feature = poolblock_out.transpose(-1, -2)
         
     | 
| 142 | 
         
            +
                    poolblock_out = self.pool_block[2](poolblock_out)
         
     | 
| 143 | 
         
            +
             
     | 
| 144 | 
         
            +
                    # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
         
     | 
| 145 | 
         
            +
                    classifier_out = (
         
     | 
| 146 | 
         
            +
                        poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
         
     | 
| 147 | 
         
            +
                    )
         
     | 
| 148 | 
         
            +
                    classifier_out, _ = self.bilstm_classifier(
         
     | 
| 149 | 
         
            +
                        classifier_out
         
     | 
| 150 | 
         
            +
                    )  # ignore the hidden states
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
                    classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
         
     | 
| 153 | 
         
            +
                    classifier_out = self.classifier(classifier_out)
         
     | 
| 154 | 
         
            +
                    classifier_out = classifier_out.view(
         
     | 
| 155 | 
         
            +
                        (-1, seq_len, self.num_class)
         
     | 
| 156 | 
         
            +
                    )  # (b, 31, num_class)
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
                    # sizes: (b, 31, 722), (b, 31, 2)
         
     | 
| 159 | 
         
            +
                    # classifier output consists of predicted pitch classes per frame
         
     | 
| 160 | 
         
            +
                    # detector output consists of: (isvoice, notvoice) estimates per frame
         
     | 
| 161 | 
         
            +
                    return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                @staticmethod
         
     | 
| 164 | 
         
            +
                def init_weights(m):
         
     | 
| 165 | 
         
            +
                    if isinstance(m, nn.Linear):
         
     | 
| 166 | 
         
            +
                        nn.init.kaiming_uniform_(m.weight)
         
     | 
| 167 | 
         
            +
                        if m.bias is not None:
         
     | 
| 168 | 
         
            +
                            nn.init.constant_(m.bias, 0)
         
     | 
| 169 | 
         
            +
                    elif isinstance(m, nn.Conv2d):
         
     | 
| 170 | 
         
            +
                        nn.init.xavier_normal_(m.weight)
         
     | 
| 171 | 
         
            +
                    elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
         
     | 
| 172 | 
         
            +
                        for p in m.parameters():
         
     | 
| 173 | 
         
            +
                            if p.data is None:
         
     | 
| 174 | 
         
            +
                                continue
         
     | 
| 175 | 
         
            +
             
     | 
| 176 | 
         
            +
                            if len(p.shape) >= 2:
         
     | 
| 177 | 
         
            +
                                nn.init.orthogonal_(p.data)
         
     | 
| 178 | 
         
            +
                            else:
         
     | 
| 179 | 
         
            +
                                nn.init.normal_(p.data)
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
            class ResBlock(nn.Module):
         
     | 
| 183 | 
         
            +
                def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
         
     | 
| 184 | 
         
            +
                    super().__init__()
         
     | 
| 185 | 
         
            +
                    self.downsample = in_channels != out_channels
         
     | 
| 186 | 
         
            +
             
     | 
| 187 | 
         
            +
                    # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
         
     | 
| 188 | 
         
            +
                    self.pre_conv = nn.Sequential(
         
     | 
| 189 | 
         
            +
                        nn.BatchNorm2d(num_features=in_channels),
         
     | 
| 190 | 
         
            +
                        nn.LeakyReLU(leaky_relu_slope, inplace=True),
         
     | 
| 191 | 
         
            +
                        nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
         
     | 
| 192 | 
         
            +
                    )
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                    # conv layers
         
     | 
| 195 | 
         
            +
                    self.conv = nn.Sequential(
         
     | 
| 196 | 
         
            +
                        nn.Conv2d(
         
     | 
| 197 | 
         
            +
                            in_channels=in_channels,
         
     | 
| 198 | 
         
            +
                            out_channels=out_channels,
         
     | 
| 199 | 
         
            +
                            kernel_size=3,
         
     | 
| 200 | 
         
            +
                            padding=1,
         
     | 
| 201 | 
         
            +
                            bias=False,
         
     | 
| 202 | 
         
            +
                        ),
         
     | 
| 203 | 
         
            +
                        nn.BatchNorm2d(out_channels),
         
     | 
| 204 | 
         
            +
                        nn.LeakyReLU(leaky_relu_slope, inplace=True),
         
     | 
| 205 | 
         
            +
                        nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
         
     | 
| 206 | 
         
            +
                    )
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
                    # 1 x 1 convolution layer to match the feature dimensions
         
     | 
| 209 | 
         
            +
                    self.conv1by1 = None
         
     | 
| 210 | 
         
            +
                    if self.downsample:
         
     | 
| 211 | 
         
            +
                        self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
         
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
                def forward(self, x):
         
     | 
| 214 | 
         
            +
                    x = self.pre_conv(x)
         
     | 
| 215 | 
         
            +
                    if self.downsample:
         
     | 
| 216 | 
         
            +
                        x = self.conv(x) + self.conv1by1(x)
         
     | 
| 217 | 
         
            +
                    else:
         
     | 
| 218 | 
         
            +
                        x = self.conv(x) + x
         
     | 
| 219 | 
         
            +
                    return x
         
     | 
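A quick smoke test for the model above (added sketch, not part of the diff). The 513-wide input in the inherited comments does not match the pooling arithmetic: for pool_block to emit a width of 2 (so that 256 * 2 = 512 feeds the BiLSTM), the frequency axis must enter at 80 bins (80 -> 40 -> 20 -> 10 -> 10 // 4 = 2), i.e. an 80-bin mel spectrogram. Under that assumption:

    import torch
    from models.codec.facodec.modules.JDC.model import JDCNet  # path per this diff

    net = JDCNet(num_class=722).eval()
    # (batch, 1, n_mels=80, frames); the frame count need not be 31
    mel = torch.randn(2, 1, 80, 31)
    with torch.no_grad():
        pitch_logits, gan_feature, pooled = net(mel)
    print(pitch_logits.shape)  # torch.Size([2, 31, 722]) per-frame pitch classes
    print(gan_feature.shape)   # torch.Size([2, 256, 10, 31])
    print(pooled.shape)        # torch.Size([2, 256, 31, 2])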
    	
models/codec/facodec/modules/attentions.py
ADDED
@@ -0,0 +1,437 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py

import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from . import commons


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x

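# Note (added sketch, not part of the upstream file): shapes of the two masks
# built in Decoder.forward above. commons.py is not shown in this diff, so the
# subsequent_mask behavior below is an assumption based on the usual VITS-style
# definition (a lower-triangular causal mask):
#
#   x_mask: (b, 1, t_dec)   1.0 for valid decoder frames, 0.0 for padding
#   h_mask: (b, 1, t_enc)   likewise for encoder frames
#   commons.subsequent_mask(t_dec)            -> (1, 1, t_dec, t_dec), tril of ones
#   h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
#       = (b, 1, 1, t_enc) * (b, 1, t_dec, 1) -> (b, 1, t_dec, t_enc)
#
# Both broadcast against the (b, n_heads, t_t, t_s) score tensor inside
# MultiHeadAttention.attention below, where positions with mask == 0 are filled
# with -1e4 before the softmax.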
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
| 336 | 
         
            +
                    ret: [b, h, l, l]
         
     | 
| 337 | 
         
            +
                    """
         
     | 
| 338 | 
         
            +
                    batch, heads, length, _ = x.size()
         
     | 
| 339 | 
         
            +
                    # Concat columns of pad to shift from relative to absolute indexing.
         
     | 
| 340 | 
         
            +
                    x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
         
     | 
| 341 | 
         
            +
             
     | 
| 342 | 
         
            +
                    # Concat extra elements so to add up to shape (len+1, 2*len-1).
         
     | 
| 343 | 
         
            +
                    x_flat = x.view([batch, heads, length * 2 * length])
         
     | 
| 344 | 
         
            +
                    x_flat = F.pad(
         
     | 
| 345 | 
         
            +
                        x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
         
     | 
| 346 | 
         
            +
                    )
         
     | 
| 347 | 
         
            +
             
     | 
| 348 | 
         
            +
                    # Reshape and slice out the padded elements.
         
     | 
| 349 | 
         
            +
                    x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
         
     | 
| 350 | 
         
            +
                        :, :, :length, length - 1 :
         
     | 
| 351 | 
         
            +
                    ]
         
     | 
| 352 | 
         
            +
                    return x_final
         
     | 
| 353 | 
         
            +
             
     | 
| 354 | 
         
            +
                def _absolute_position_to_relative_position(self, x):
         
     | 
| 355 | 
         
            +
                    """
         
     | 
| 356 | 
         
            +
                    x: [b, h, l, l]
         
     | 
| 357 | 
         
            +
                    ret: [b, h, l, 2*l-1]
         
     | 
| 358 | 
         
            +
                    """
         
     | 
| 359 | 
         
            +
                    batch, heads, length, _ = x.size()
         
     | 
| 360 | 
         
            +
                    # padd along column
         
     | 
| 361 | 
         
            +
                    x = F.pad(
         
     | 
| 362 | 
         
            +
                        x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
         
     | 
| 363 | 
         
            +
                    )
         
     | 
| 364 | 
         
            +
                    x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         
     | 
| 365 | 
         
            +
                    # add 0's in the beginning that will skew the elements after reshape
         
     | 
| 366 | 
         
            +
                    x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         
     | 
| 367 | 
         
            +
                    x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         
     | 
| 368 | 
         
            +
                    return x_final
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
                def _attention_bias_proximal(self, length):
         
     | 
| 371 | 
         
            +
                    """Bias for self-attention to encourage attention to close positions.
         
     | 
| 372 | 
         
            +
                    Args:
         
     | 
| 373 | 
         
            +
                      length: an integer scalar.
         
     | 
| 374 | 
         
            +
                    Returns:
         
     | 
| 375 | 
         
            +
                      a Tensor with shape [1, 1, length, length]
         
     | 
| 376 | 
         
            +
                    """
         
     | 
| 377 | 
         
            +
                    r = torch.arange(length, dtype=torch.float32)
         
     | 
| 378 | 
         
            +
                    diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
         
     | 
| 379 | 
         
            +
                    return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
         
     | 
| 380 | 
         
            +
             
     | 
| 381 | 
         
            +
             
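
# Illustrative sketch (not part of the committed file; uses the torch / F
# imports already at the top of this module): the pad-and-reshape trick from
# _relative_position_to_absolute_position, standalone for l = 3. Scores at
# relative offsets -2..+2 ([b, h, 3, 5]) become an absolute [b, h, 3, 3] grid
# with two pads and two reshapes, no gather or conditional ops.
_b, _h, _l = 1, 1, 3
_rel = torch.randn(_b, _h, _l, 2 * _l - 1)
_x = F.pad(_rel, [0, 1])                                   # [b, h, l, 2l]
_x_flat = F.pad(_x.reshape(_b, _h, _l * 2 * _l), [0, _l - 1])
_absolute = _x_flat.reshape(_b, _h, _l + 1, 2 * _l - 1)[:, :, :_l, _l - 1 :]
assert _absolute.shape == (_b, _h, _l, _l)
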
class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)  # sigmoid approximation of GELU
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
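# Illustrative sketch (not part of the diff; plain PyTorch): both FFN padding
# helpers above add kernel_size - 1 frames in total, so the Conv1d layers
# preserve sequence length; only the left/right split differs, with causal
# padding putting every extra frame on the left so no output sees the future.
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 10)                     # [batch, channels, time]
k = 5
causal = F.pad(x, [k - 1, 0])                 # left-only padding
same = F.pad(x, [(k - 1) // 2, k // 2])       # balanced padding
for padded in (causal, same):
    out = torch.nn.Conv1d(8, 8, k)(padded)
    assert out.shape[-1] == x.shape[-1]       # length preserved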
        models/codec/facodec/modules/commons.py
    ADDED
@@ -0,0 +1,331 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import math
import os.path

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments_audio(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
        dtype=torch.long
    )
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
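
# Illustrative usage (not part of the committed file): rand_slice_segments
# draws one random fixed-length window per batch item, e.g. training-time mel
# crops; ids_str records where each window starts.
_x = torch.randn(4, 80, 100)                  # [batch, mel bins, frames]
_lengths = torch.tensor([100, 90, 80, 70])
_segments, _ids_str = rand_slice_segments(_x, _lengths, segment_size=32)
assert _segments.shape == (4, 80, 32)
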
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    # NOTE: identical redefinition of convert_pad_shape above; harmless, but
    # one copy would suffice.
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
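
# Illustrative usage (not part of the committed file): generate_path expands
# per-token durations into a hard monotonic alignment; here 2 frames map to
# token 0 and the remaining 3 frames to token 1.
_duration = torch.tensor([[[2.0, 3.0]]])      # [b=1, 1, t_x=2]
_attn_mask = torch.ones(1, 1, 5, 2)           # [b, 1, t_y=5, t_x=2]
_path = generate_path(_duration, _attn_mask)  # [1, 1, 5, 2], one-hot per frame
assert _path[0, 0, 0, 0] == 1 and _path[0, 0, 4, 1] == 1
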
def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)
    """
    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
    return x


from huggingface_hub import hf_hub_download


def load_F0_models(path):
    # load F0 model
    from .JDC.model import JDCNet

    F0_model = JDCNet(num_class=1, seq_len=192)
    if not os.path.exists(path):
        path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
    params = torch.load(path, map_location="cpu")["net"]
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model


# Generators
from modules.dac.model.dac import Encoder, Decoder
from .quantize import FAquantizer, FApredictors

# Discriminators
from modules.dac.model.discriminator import Discriminator


def build_model(args):
    encoder = Encoder(
        d_model=args.DAC.encoder_dim,
        strides=args.DAC.encoder_rates,
        d_latent=1024,
        causal=args.causal,
        lstm=args.lstm,
    )

    quantizer = FAquantizer(
        in_dim=1024,
        n_p_codebooks=1,
        n_c_codebooks=args.n_c_codebooks,
        n_t_codebooks=2,
        n_r_codebooks=3,
        codebook_size=1024,
        codebook_dim=8,
        quantizer_dropout=0.5,
        causal=args.causal,
        separate_prosody_encoder=args.separate_prosody_encoder,
        timbre_norm=args.timbre_norm,
    )

    fa_predictors = FApredictors(
        in_dim=1024,
        use_gr_content_f0=args.use_gr_content_f0,
        use_gr_prosody_phone=args.use_gr_prosody_phone,
        use_gr_residual_f0=True,
        use_gr_residual_phone=True,
        use_gr_timbre_content=True,
        use_gr_timbre_prosody=args.use_gr_timbre_prosody,
        use_gr_x_timbre=True,
        norm_f0=args.norm_f0,
        timbre_norm=args.timbre_norm,
        use_gr_content_global_f0=args.use_gr_content_global_f0,
    )

    decoder = Decoder(
        input_channel=1024,
        channels=args.DAC.decoder_dim,
        rates=args.DAC.decoder_rates,
        causal=args.causal,
        lstm=args.lstm,
    )

    discriminator = Discriminator(
        rates=[],
        periods=[2, 3, 5, 7, 11],
        fft_sizes=[2048, 1024, 512],
        sample_rate=args.DAC.sr,
        bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
    )

    nets = Munch(
        encoder=encoder,
        quantizer=quantizer,
        decoder=decoder,
        discriminator=discriminator,
        fa_predictors=fa_predictors,
    )

    return nets


def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip the DDP prefix ("module.") in place so the keys match
                # a non-DDP model
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            print("%s loaded" % key)
            model[key].load_state_dict(params[key], strict=True)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])
    else:
        epoch = state["epoch"] + 1
        iters = state["iters"]

    return model, optimizer, epoch, iters


def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
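# Illustrative usage (not part of the diff; assumes the helpers above are in
# scope, e.g. via `from models.codec.facodec.modules import commons`):
# recursive_munch turns a nested config dict into attribute-accessible
# Munches, the shape build_model expects. All values here are placeholders.
_cfg = recursive_munch(
    {"causal": True, "DAC": {"sr": 24000, "encoder_rates": [2, 5, 5, 6]}}
)
assert _cfg.DAC.sr == 24000 and _cfg.causal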
        models/codec/facodec/modules/gradient_reversal.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.autograd import Function
import torch
from torch import nn


class GradientReversal(Function):
    # Identity in the forward pass; scales the incoming gradient by -alpha in
    # the backward pass.
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        _, alpha = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -alpha * grad_output
        return grad_input, None


revgrad = GradientReversal.apply


class GradientReversal(nn.Module):
    # nn.Module wrapper; intentionally reuses the class name, since only
    # `revgrad` (bound above) still refers to the autograd Function.
    def __init__(self, alpha):
        super().__init__()
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self.alpha)
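# Illustrative usage (not part of the diff; assumes the module above is in
# scope): the reversal layer is an identity in the forward pass but scales
# gradients by -alpha on the way back, so an adversarial predictor trains
# normally while the encoder feeding it is pushed to discard the predicted
# factor.
import torch
from torch import nn

_grl = GradientReversal(alpha=1.0)
_clf = nn.Linear(16, 4)
_feats = torch.randn(8, 16, requires_grad=True)
_clf(_grl(_feats)).sum().backward()
# _feats.grad is exactly the negation of what the classifier sent back.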
        models/codec/facodec/modules/layers.py
ADDED
@@ -0,0 +1,460 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
from torch import nn
from typing import Optional, Any
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F

import random

random.seed(0)


def _get_activation_fn(activ):
    if activ == "relu":
        return nn.ReLU()
    elif activ == "lrelu":
        return nn.LeakyReLU(0.2)
    elif activ == "swish":
        # nn.SiLU computes x * sigmoid(x); the module form (unlike a bare
        # lambda) can be placed inside nn.Sequential below.
        return nn.SiLU()
    else:
        raise RuntimeError(
            "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
        )


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class CausualConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=1,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(CausualConv, self).__init__()
        # Pad with twice the one-sided amount, then trim the tail in forward()
        # so the convolution never sees future frames.
        if padding is None:
            assert kernel_size % 2 == 1
            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )

        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x[:, :, : -self.padding]
        return x


class CausualBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
        layers = [
            CausualConv(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)


class ConvBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
        layers = [
            ConvNorm(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2,
            attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=padding,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights + processed_memory)
        )

        energies = energies.squeeze(-1)
        return energies

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class ForwardAttentionV2(nn.Module):
    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # A large finite negative value instead of -inf keeps the logsumexp
        # below numerically well behaved.
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights + processed_memory)
        )

        energies = energies.squeeze(-1)
        return energies

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
        log_alpha,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        log_alpha: log attention weights from the previous decoder step
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )

        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        # Forward attention in log space: position t may inherit probability
        # mass only from itself (shift 0) or its left neighbour (shift 1), so
        # log_alpha_new[t] = logsumexp(log_alpha[t], log_alpha[t-1]) + log_energy[t].
        log_alpha_shift_padded = []
        max_time = log_energy.size(1)
        for sft in range(2):
            shifted = log_alpha[:, : max_time - sft]
            shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))

        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)

        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)

        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights, log_alpha_new


class PhaseShuffle2d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle2d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, M, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :, :move]
            right = x[:, :, :, move:]
            shuffled = torch.cat([right, left], dim=3)
        return shuffled


class PhaseShuffle1d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle1d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :move]
            right = x[:, :, move:]
            shuffled = torch.cat([right, left], dim=2)

        return shuffled


class MFCC(nn.Module):
    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)

    def forward(self, mel_specgram):
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc
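
As a quick sanity check of the modules above, here is a minimal usage sketch; the import path, dimensions, and tensor shapes are illustrative assumptions and not part of the file:

import torch
from models.codec.facodec.modules.layers import (  # path assumed from the repo layout
    MFCC,
    Attention,
    ForwardAttentionV2,
)

B, T = 2, 50
attn = Attention(
    attention_rnn_dim=256,
    embedding_dim=128,
    attention_dim=64,
    attention_location_n_filters=32,
    attention_location_kernel_size=31,
)
query = torch.randn(B, 256)            # last attention-RNN state
memory = torch.randn(B, T, 128)        # encoder outputs
processed = attn.memory_layer(memory)  # typically precomputed once per utterance
weights_cat = torch.zeros(B, 2, T)     # previous + cumulative attention weights
ctx, w = attn(query, memory, processed, weights_cat, mask=None)
print(ctx.shape, w.shape)              # torch.Size([2, 128]) torch.Size([2, 50])

fattn = ForwardAttentionV2(256, 128, 64, 32, 31)
log_alpha = torch.full((B, T), -1e4)
log_alpha[:, 0] = 0.0                  # all probability mass starts at position 0
ctx2, w2, log_alpha = fattn(
    query, memory, fattn.memory_layer(memory), weights_cat, None, log_alpha
)

mfcc = MFCC(n_mfcc=40, n_mels=80)
print(mfcc(torch.randn(B, 80, 100)).shape)  # torch.Size([2, 40, 100])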
    	
models/codec/facodec/modules/quantize.py
ADDED
@@ -0,0 +1,741 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from modules.dac.nn.quantize import ResidualVectorQuantize
from torch import nn
from .wavenet import WN
from .style_encoder import StyleEncoder
from .gradient_reversal import GradientReversal
import torch
import torchaudio
import torchaudio.functional as audio_F
import numpy as np
from ..alias_free_torch import *
from torch.nn.utils import weight_norm
from torch import nn, sin, pow
from einops.layers.torch import Rearrange
from modules.dac.model.encodec import SConv1d


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter that controls frequency
            - beta: trainable parameter that controls magnitude
            alpha is initialized to 1 by default; higher values mean higher frequency.
            beta is initialized to 1 by default; higher values mean higher magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha (beta reuses alpha's initial value)
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log-scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:  # linear-scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(x * a)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x

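# Illustrative sketch (not part of the original file): SnakeBeta is applied
# channelwise and preserves the (B, C, T) shape; at the default log-scale init
# (alpha = beta = exp(0) = 1) it reduces to roughly x + sin^2(x).
def _demo_snakebeta():
    act = SnakeBeta(256, alpha_logscale=True)
    x = torch.randn(4, 256, 100)  # (B, C, T)
    assert act(x).shape == x.shape
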
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        return x + self.block(x)


class CNNLSTM(nn.Module):
    def __init__(self, indim, outdim, head, global_pred=False):
        super().__init__()
        self.global_pred = global_pred
        self.model = nn.Sequential(
            ResidualUnit(indim, dilation=1),
            ResidualUnit(indim, dilation=2),
            ResidualUnit(indim, dilation=3),
            Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
            Rearrange("b c t -> b t c"),
        )
        self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])

    def forward(self, x):
        # x: [B, C, T]
        x = self.model(x)
        if self.global_pred:
            x = torch.mean(x, dim=1, keepdim=False)
        outs = [head(x) for head in self.heads]
        return outs

def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

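# Illustrative example (not part of the original file): sequence_mask turns
# per-item lengths into a boolean (B, T_max) padding mask, e.g.
#   sequence_mask(torch.tensor([2, 4]), max_length=5)
#   -> tensor([[ True,  True, False, False, False],
#              [ True,  True,  True,  True, False]])
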
class MFCC(nn.Module):
    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)

    def forward(self, mel_specgram):
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc

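# Illustrative sketch (not part of the original file): MFCC applies a DCT over
# the mel axis, so a (B, n_mels, T) input maps to (B, n_mfcc, T); 2-D inputs
# are handled by temporarily adding a batch dimension.
def _demo_mfcc():
    mfcc = MFCC(n_mfcc=40, n_mels=80)
    mel = torch.randn(2, 80, 50)  # (B, n_mels, T)
    assert mfcc(mel).shape == (2, 40, 50)
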
class FAquantizer(nn.Module):
    def __init__(
        self,
        in_dim=1024,
        n_p_codebooks=1,
        n_c_codebooks=2,
        n_t_codebooks=2,
        n_r_codebooks=3,
        codebook_size=1024,
        codebook_dim=8,
        quantizer_dropout=0.5,
        causal=False,
        separate_prosody_encoder=False,
        timbre_norm=False,
    ):
        super(FAquantizer, self).__init__()
        conv1d_type = SConv1d  # if causal else nn.Conv1d
        self.prosody_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_p_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.content_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_c_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        if not timbre_norm:
            self.timbre_quantizer = ResidualVectorQuantize(
                input_dim=in_dim,
                n_codebooks=n_t_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            self.timbre_encoder = StyleEncoder(
                in_dim=80, hidden_dim=512, out_dim=in_dim
            )
            self.timbre_linear = nn.Linear(1024, 1024 * 2)
            self.timbre_linear.bias.data[:1024] = 1
            self.timbre_linear.bias.data[1024:] = 0
            self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)

        self.residual_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_r_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        if separate_prosody_encoder:
            self.melspec_linear = conv1d_type(
                in_channels=20, out_channels=256, kernel_size=1, causal=causal
            )
            self.melspec_encoder = WN(
                hidden_channels=256,
                kernel_size=5,
                dilation_rate=1,
                n_layers=8,
                gin_channels=0,
                p_dropout=0.2,
                causal=causal,
            )
            self.melspec_linear2 = conv1d_type(
                in_channels=256, out_channels=1024, kernel_size=1, causal=causal
            )
        self.separate_prosody_encoder = separate_prosody_encoder

        self.prob_random_mask_residual = 0.75

        SPECT_PARAMS = {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
        }
        MEL_PARAMS = {
            "n_mels": 80,
        }

        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
        )
        self.mel_mean, self.mel_std = -4, 4
        self.frame_rate = 24000 / 300
        self.hop_length = 300

        self.is_timbre_norm = timbre_norm
        if timbre_norm:
            self.forward = self.forward_v2

    def preprocess(self, wave_tensor, n_bins=20):
        mel_tensor = self.to_mel(wave_tensor.squeeze(1))
        mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
        return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]

    @torch.no_grad()
    def decode(self, codes):
        code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)

        z_c = self.content_quantizer.from_codes(code_c)[0]
        z_p = self.prosody_quantizer.from_codes(code_p)[0]
        z_t = self.timbre_quantizer.from_codes(code_t)[0]

        z = z_c + z_p + z_t

        return z, [z_c, z_p, z_t]

    @torch.no_grad()
    def encode(self, x, wave_segments, n_c=1):
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, 20, T)
            f0_input = self.melspec_linear(f0_input)
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        timbre_residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_t,
            codes_t,
            latents_t,
            commitment_loss_t,
            codebook_loss_t,
        ) = self.timbre_quantizer(timbre_residual_feature, 2)
        outs += z_t  # we should not detach timbre

        residual_feature = timbre_residual_feature - z_t

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]

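    # Illustrative note (not part of the original file): decode() above expects
    # the code indices stacked along dim=1 as [content | prosody | timbre] with
    # 1 + 1 + 2 codebooks, matching its split([1, 1, 2], dim=1), while encode()
    # returns the four streams (content, prosody, timbre, residual) separately.
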
    def forward(
        self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
    ):
        # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
        # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, 20, T)
            f0_input = self.melspec_linear(f0_input)
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        timbre_residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_t,
            codes_t,
            latents_t,
            commitment_loss_t,
            codebook_loss_t,
        ) = self.timbre_quantizer(timbre_residual_feature, n_t)
        outs += z_t  # we should not detach timbre

        residual_feature = timbre_residual_feature - z_t

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        bsz = z_r.shape[0]
        res_mask = np.random.choice(
            [0, 1],
            size=bsz,
            p=[
                self.prob_random_mask_residual,
                1 - self.prob_random_mask_residual,
            ],
        )
        res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)  # (B, 1, 1)
        res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
        noise_must_on = noise_added_flags * recon_noisy_flags
        noise_must_off = noise_added_flags * (~recon_noisy_flags)
        res_mask[noise_must_on] = 1
        res_mask[noise_must_off] = 0

        outs += z_r * res_mask

        quantized = [z_p, z_c, z_t, z_r]
        commitment_losses = (
            commitment_loss_p
            + commitment_loss_c
            + commitment_loss_t
            + commitment_loss_r
        )
        codebook_losses = (
            codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
        )

        return outs, quantized, commitment_losses, codebook_losses

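    # Illustrative note (not part of the original file): in forward() above,
    # np.random.choice draws 0 with probability prob_random_mask_residual, so
    # the residual stream z_r is zeroed for roughly 75% of batch items during
    # training; the noise flags then force the mask on or off for specific
    # items, and the (B, 1, 1) mask broadcasts over the (B, C, T) latents.
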
    def forward_v2(
        self,
        x,
        wave_segments,
        n_c=1,
        n_t=2,
        full_waves=None,
        wave_lens=None,
        return_codes=False,
    ):
        # timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
        if full_waves is None:
            mel = self.preprocess(wave_segments, n_bins=80)
            timbre = self.timbre_encoder(
                mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
            )
        else:
            mel = self.preprocess(full_waves, n_bins=80)
            timbre = self.timbre_encoder(
                mel,
                sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
            )
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, 20, T)
            f0_input = self.melspec_linear(f0_input)
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        bsz = z_r.shape[0]
        res_mask = np.random.choice(
            [0, 1],
            size=bsz,
            p=[
                self.prob_random_mask_residual,
                1 - self.prob_random_mask_residual,
            ],
        )
        res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)  # (B, 1, 1)
        res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)

        if not self.training:
            res_mask = torch.ones_like(res_mask)
        outs += z_r * res_mask

        quantized = [z_p, z_c, z_r]
        codes = [codes_p, codes_c, codes_r]
        commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
        codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r

        style = self.timbre_linear(timbre).unsqueeze(2)  # (B, 2d, 1)
        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
        outs = outs.transpose(1, 2)
        outs = self.timbre_norm(outs)
        outs = outs.transpose(1, 2)
        outs = outs * gamma + beta

        if return_codes:
            return outs, quantized, commitment_losses, codebook_losses, timbre, codes
        else:
            return outs, quantized, commitment_losses, codebook_losses, timbre

    def voice_conversion(self, z, ref_wave):
        ref_mel = self.preprocess(ref_wave, n_bins=80)
        ref_timbre = self.timbre_encoder(
            ref_mel,
            sequence_mask(
                torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
                ref_mel.size(-1),
            ).unsqueeze(1),
        )
        style = self.timbre_linear(ref_timbre).unsqueeze(2)  # (B, 2d, 1)
        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
        outs = z.transpose(1, 2)
        outs = self.timbre_norm(outs)
        outs = outs.transpose(1, 2)
        outs = outs * gamma + beta

        return outs

| 590 | 
         
            +
            class FApredictors(nn.Module):
         
     | 
| 591 | 
         
            +
                def __init__(
         
     | 
| 592 | 
         
            +
                    self,
         
     | 
| 593 | 
         
            +
                    in_dim=1024,
         
     | 
| 594 | 
         
            +
                    use_gr_content_f0=False,
         
     | 
| 595 | 
         
            +
                    use_gr_prosody_phone=False,
         
     | 
| 596 | 
         
            +
                    use_gr_residual_f0=False,
         
     | 
| 597 | 
         
            +
                    use_gr_residual_phone=False,
         
     | 
| 598 | 
         
            +
                    use_gr_timbre_content=True,
         
     | 
| 599 | 
         
            +
                    use_gr_timbre_prosody=True,
         
     | 
| 600 | 
         
            +
                    use_gr_x_timbre=False,
         
     | 
| 601 | 
         
            +
                    norm_f0=True,
         
     | 
| 602 | 
         
            +
                    timbre_norm=False,
         
     | 
| 603 | 
         
            +
                    use_gr_content_global_f0=False,
         
     | 
| 604 | 
         
            +
                ):
         
     | 
| 605 | 
         
            +
                    super(FApredictors, self).__init__()
         
     | 
| 606 | 
         
            +
                    self.f0_predictor = CNNLSTM(in_dim, 1, 2)
         
     | 
| 607 | 
         
            +
                    self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
         
     | 
| 608 | 
         
            +
                    if timbre_norm:
         
     | 
| 609 | 
         
            +
                        self.timbre_predictor = nn.Linear(in_dim, 20000)
         
     | 
| 610 | 
         
            +
                    else:
         
     | 
| 611 | 
         
            +
                        self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)
         
     | 
| 612 | 
         
            +
             
     | 
| 613 | 
         
            +
                    self.use_gr_content_f0 = use_gr_content_f0
         
     | 
| 614 | 
         
            +
                    self.use_gr_prosody_phone = use_gr_prosody_phone
         
     | 
| 615 | 
         
            +
                    self.use_gr_residual_f0 = use_gr_residual_f0
         
     | 
| 616 | 
         
            +
                    self.use_gr_residual_phone = use_gr_residual_phone
         
     | 
| 617 | 
         
            +
                    self.use_gr_timbre_content = use_gr_timbre_content
         
     | 
| 618 | 
         
            +
                    self.use_gr_timbre_prosody = use_gr_timbre_prosody
         
     | 
| 619 | 
         
            +
                    self.use_gr_x_timbre = use_gr_x_timbre
         
     | 
| 620 | 
         
            +
             
     | 
| 621 | 
         
            +
                    self.rev_f0_predictor = nn.Sequential(
         
     | 
| 622 | 
         
            +
                        GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
         
     | 
| 623 | 
         
            +
                    )
         
     | 
| 624 | 
         
            +
                    self.rev_content_predictor = nn.Sequential(
         
     | 
| 625 | 
         
            +
                        GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
         
     | 
| 626 | 
         
            +
                    )
         
     | 
| 627 | 
         
            +
                    self.rev_timbre_predictor = nn.Sequential(
         
     | 
| 628 | 
         
            +
                        GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
         
     | 
| 629 | 
         
            +
                    )
         
     | 
| 630 | 
         
            +
             
     | 
| 631 | 
         
            +
                    self.norm_f0 = norm_f0
         
     | 
| 632 | 
         
            +
                    self.timbre_norm = timbre_norm
         
     | 
| 633 | 
         
            +
                    if timbre_norm:
         
     | 
| 634 | 
         
            +
                        self.forward = self.forward_v2
         
     | 
| 635 | 
         
            +
                        self.global_f0_predictor = nn.Linear(in_dim, 1)
         
     | 
| 636 | 
         
            +
             
     | 
| 637 | 
         
            +
                    self.use_gr_content_global_f0 = use_gr_content_global_f0
         
     | 
| 638 | 
         
            +
                    if use_gr_content_global_f0:
         
     | 
| 639 | 
         
            +
                        self.rev_global_f0_predictor = nn.Sequential(
         
     | 
| 640 | 
         
            +
                            GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
         
     | 
| 641 | 
         
            +
                        )
         
     | 
| 642 | 
         
            +
             
     | 
| 643 | 
         
            +
                def forward(self, quantized):
         
     | 
| 644 | 
         
            +
                    prosody_latent = quantized[0]
         
     | 
| 645 | 
         
            +
                    content_latent = quantized[1]
         
     | 
| 646 | 
         
            +
                    timbre_latent = quantized[2]
         
     | 
| 647 | 
         
            +
                    residual_latent = quantized[3]
         
     | 
| 648 | 
         
            +
                    content_pred = self.phone_predictor(content_latent)[0]
         
     | 
| 649 | 
         
            +
             
     | 
| 650 | 
         
            +
                    if self.norm_f0:
         
     | 
| 651 | 
         
            +
                        spk_pred = self.timbre_predictor(timbre_latent)[0]
         
     | 
| 652 | 
         
            +
                        f0_pred, uv_pred = self.f0_predictor(prosody_latent)
         
     | 
| 653 | 
         
            +
                    else:
         
     | 
| 654 | 
         
            +
                        spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
         
     | 
| 655 | 
         
            +
                        f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)
         
     | 
| 656 | 
         
            +
             
     | 
| 657 | 
         
            +
                    prosody_rev_latent = torch.zeros_like(quantized[0])
         
     | 
| 658 | 
         
            +
                    if self.use_gr_content_f0:
         
     | 
| 659 | 
         
            +
                        prosody_rev_latent += quantized[1]
         
     | 
| 660 | 
         
            +
                    if self.use_gr_timbre_prosody:
         
     | 
| 661 | 
         
            +
                        prosody_rev_latent += quantized[2]
         
     | 
| 662 | 
         
            +
                    if self.use_gr_residual_f0:
         
     | 
| 663 | 
         
            +
                        prosody_rev_latent += quantized[3]
         
     | 
| 664 | 
         
            +
                    rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
         
     | 
| 665 | 
         
            +
             
     | 
| 666 | 
         
            +
                    content_rev_latent = torch.zeros_like(quantized[1])
         
     | 
| 667 | 
         
            +
                    if self.use_gr_prosody_phone:
         
     | 
| 668 | 
         
            +
                        content_rev_latent += quantized[0]
         
     | 
| 669 | 
         
            +
                    if self.use_gr_timbre_content:
         
     | 
| 670 | 
         
            +
                        content_rev_latent += quantized[2]
         
     | 
| 671 | 
         
            +
                    if self.use_gr_residual_phone:
         
     | 
| 672 | 
         
            +
                        content_rev_latent += quantized[3]
         
     | 
| 673 | 
         
            +
                    rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
         
     | 
| 674 | 
         
            +
             
     | 
| 675 | 
         
            +
                    if self.norm_f0:
         
     | 
| 676 | 
         
            +
                        timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
         
     | 
| 677 | 
         
            +
                    else:
         
     | 
| 678 | 
         
            +
                        timbre_rev_latent = quantized[1] + quantized[3]
         
     | 
| 679 | 
         
            +
                    if self.use_gr_x_timbre:
         
     | 
| 680 | 
         
            +
                        x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
         
     | 
| 681 | 
         
            +
                    else:
         
     | 
| 682 | 
         
            +
                        x_spk_pred = None
         
     | 
| 683 | 
         
            +
             
     | 
| 684 | 
         
            +
                    preds = {
         
     | 
| 685 | 
         
            +
                        "f0": f0_pred,
         
     | 
| 686 | 
         
            +
                        "uv": uv_pred,
         
     | 
| 687 | 
         
            +
                        "content": content_pred,
         
     | 
| 688 | 
         
            +
                        "timbre": spk_pred,
         
     | 
| 689 | 
         
            +
                    }
         
     | 
| 690 | 
         
            +
             
     | 
| 691 | 
         
            +
                    rev_preds = {
         
     | 
| 692 | 
         
            +
                        "rev_f0": rev_f0_pred,
         
     | 
| 693 | 
         
            +
                        "rev_uv": rev_uv_pred,
         
     | 
| 694 | 
         
            +
                        "rev_content": rev_content_pred,
         
     | 
| 695 | 
         
            +
                        "x_timbre": x_spk_pred,
         
     | 
| 696 | 
         
            +
                    }
         
     | 
| 697 | 
         
            +
                    return preds, rev_preds
         
     | 
| 698 | 
         
            +
             
     | 
| 699 | 
         
            +
                def forward_v2(self, quantized, timbre):
         
     | 
| 700 | 
         
            +
                    prosody_latent = quantized[0]
         
     | 
| 701 | 
         
            +
                    content_latent = quantized[1]
         
     | 
| 702 | 
         
            +
                    residual_latent = quantized[2]
         
     | 
| 703 | 
         
            +
                    content_pred = self.phone_predictor(content_latent)[0]
         
     | 
| 704 | 
         
            +
             
     | 
| 705 | 
         
            +
                    spk_pred = self.timbre_predictor(timbre)
         
     | 
| 706 | 
         
            +
                    f0_pred, uv_pred = self.f0_predictor(prosody_latent)
         
     | 
| 707 | 
         
            +
             
     | 
| 708 | 
         
            +
                    prosody_rev_latent = torch.zeros_like(prosody_latent)
         
     | 
| 709 | 
         
            +
                    if self.use_gr_content_f0:
         
     | 
| 710 | 
         
            +
                        prosody_rev_latent += content_latent
         
     | 
| 711 | 
         
            +
                    if self.use_gr_residual_f0:
         
     | 
| 712 | 
         
            +
                        prosody_rev_latent += residual_latent
         
     | 
| 713 | 
         
            +
                    rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
         
     | 
| 714 | 
         
            +
             
     | 
| 715 | 
         
            +
                    content_rev_latent = torch.zeros_like(content_latent)
         
     | 
| 716 | 
         
            +
                    if self.use_gr_prosody_phone:
         
     | 
| 717 | 
         
            +
                        content_rev_latent += prosody_latent
         
     | 
| 718 | 
         
            +
                    if self.use_gr_residual_phone:
         
     | 
| 719 | 
         
            +
                        content_rev_latent += residual_latent
         
     | 
| 720 | 
         
            +
                    rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
         
     | 
| 721 | 
         
            +
             
     | 
| 722 | 
         
            +
                    timbre_rev_latent = prosody_latent + content_latent + residual_latent
         
     | 
| 723 | 
         
            +
                    if self.use_gr_x_timbre:
         
     | 
| 724 | 
         
            +
                        x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
         
     | 
| 725 | 
         
            +
                    else:
         
     | 
| 726 | 
         
            +
                        x_spk_pred = None
         
     | 
| 727 | 
         
            +
             
     | 
| 728 | 
         
            +
                    preds = {
         
     | 
| 729 | 
         
            +
                        "f0": f0_pred,
         
     | 
| 730 | 
         
            +
                        "uv": uv_pred,
         
     | 
| 731 | 
         
            +
                        "content": content_pred,
         
     | 
| 732 | 
         
            +
                        "timbre": spk_pred,
         
     | 
| 733 | 
         
            +
                    }
         
     | 
| 734 | 
         
            +
             
     | 
| 735 | 
         
            +
                    rev_preds = {
         
     | 
| 736 | 
         
            +
                        "rev_f0": rev_f0_pred,
         
     | 
| 737 | 
         
            +
                        "rev_uv": rev_uv_pred,
         
     | 
| 738 | 
         
            +
                        "rev_content": rev_content_pred,
         
     | 
| 739 | 
         
            +
                        "x_timbre": x_spk_pred,
         
     | 
| 740 | 
         
            +
                    }
         
     | 
| 741 | 
         
            +
                    return preds, rev_preds
         
     | 
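Editor's note: FApredictors attaches an auxiliary predictor to each disentangled stream and routes deliberately mixed streams through GradientReversal, so the upstream quantizers are penalized whenever, e.g., content information leaks into the prosody code. A minimal sketch of the gradient-reversal behavior this relies on (illustrative only; the repo's own layer lives in models/codec/facodec/modules/gradient_reversal.py):

    # Sketch: identity on the forward pass, sign-flipped gradient on the backward pass.
    import torch

    class GradientReversalFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
            return x.clone()  # behaves as identity going forward

        @staticmethod
        def backward(ctx, grad_output):
            # Flip the gradient so the encoder learns to REMOVE whatever
            # this adversarial branch can still predict.
            return -ctx.alpha * grad_output, None

    x = torch.randn(2, 8, requires_grad=True)
    GradientReversalFn.apply(x, 1.0).sum().backward()
    assert torch.allclose(x.grad, -torch.ones_like(x))  # gradients reversed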
    	
models/codec/facodec/modules/style_encoder.py
ADDED
@@ -0,0 +1,110 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py
+
+from . import attentions
+from torch import nn
+import torch
+from torch.nn import functional as F
+
+
+class Mish(nn.Module):
+    def __init__(self):
+        super(Mish, self).__init__()
+
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+class Conv1dGLU(nn.Module):
+    """
+    Conv1d + GLU (Gated Linear Unit) with residual connection.
+    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, dropout):
+        super(Conv1dGLU, self).__init__()
+        self.out_channels = out_channels
+        self.conv1 = nn.Conv1d(
+            in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2
+        )
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv1(x)
+        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
+        x = x1 * torch.sigmoid(x2)
+        x = residual + self.dropout(x)
+        return x
+
+
+class StyleEncoder(torch.nn.Module):
+    def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
+
+        super().__init__()
+
+        self.in_dim = in_dim  # Linear 513 wav2vec 2.0 1024
+        self.hidden_dim = hidden_dim
+        self.out_dim = out_dim
+        self.kernel_size = 5
+        self.n_head = 2
+        self.dropout = 0.1
+
+        self.spectral = nn.Sequential(
+            nn.Conv1d(self.in_dim, self.hidden_dim, 1),
+            Mish(),
+            nn.Dropout(self.dropout),
+            nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
+            Mish(),
+            nn.Dropout(self.dropout),
+        )
+
+        self.temporal = nn.Sequential(
+            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+        )
+
+        self.slf_attn = attentions.MultiHeadAttention(
+            self.hidden_dim,
+            self.hidden_dim,
+            self.n_head,
+            p_dropout=self.dropout,
+            proximal_bias=False,
+            proximal_init=True,
+        )
+        self.atten_drop = nn.Dropout(self.dropout)
+        self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
+
+    def forward(self, x, mask=None):
+
+        # spectral
+        x = self.spectral(x) * mask
+        # temporal
+        x = self.temporal(x) * mask
+
+        # self-attention
+        attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
+        y = self.slf_attn(x, x, attn_mask=attn_mask)
+        x = x + self.atten_drop(y)
+
+        # fc
+        x = self.fc(x)
+
+        # temporal average pooling
+        w = self.temporal_avg_pool(x, mask=mask)
+
+        return w
+
+    def temporal_avg_pool(self, x, mask=None):
+        if mask is None:
+            out = torch.mean(x, dim=2)
+        else:
+            len_ = mask.sum(dim=2)
+            x = x.sum(dim=2)
+
+            out = torch.div(x, len_)
+        return out
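Editor's note: StyleEncoder takes a (B, in_dim, T) spectrogram and a (B, 1, T) validity mask and returns one (B, out_dim) utterance-level style vector; everything funnels through the masked temporal average pooling at the end. A self-contained sketch of that pooling step (hypothetical shapes, not a call into this module):

    # Sketch: mean over valid frames only, so padding does not dilute the style vector.
    import torch

    def masked_temporal_avg_pool(x, mask):
        # x: (B, C, T) frame features, already zeroed outside the mask
        # mask: (B, 1, T), 1.0 on valid frames
        len_ = mask.sum(dim=2)        # (B, 1) valid-frame counts
        return x.sum(dim=2) / len_    # (B, C)

    x = torch.randn(2, 256, 100)
    mask = torch.ones(2, 1, 100)
    mask[1, :, 60:] = 0               # second utterance has only 60 real frames
    print(masked_temporal_avg_pool(x * mask, mask).shape)  # torch.Size([2, 256])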
    	
models/codec/facodec/modules/wavenet.py
ADDED
@@ -0,0 +1,224 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from modules.dac.model.encodec import SConv1d
+
+from . import commons
+
+LRELU_SLOPE = 0.1
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        kernel_size,
+        n_layers,
+        p_dropout,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        assert n_layers > 1, "Number of layers should be larger than 1."
+
+        self.conv_layers = nn.ModuleList()
+        self.norm_layers = nn.ModuleList()
+        self.conv_layers.append(
+            nn.Conv1d(
+                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+            )
+        )
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(
+                nn.Conv1d(
+                    hidden_channels,
+                    hidden_channels,
+                    kernel_size,
+                    padding=kernel_size // 2,
+                )
+            )
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
+
+
+class DDSConv(nn.Module):
+    """
+    Dilated and Depth-Separable Convolution
+    """
+
+    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+        super().__init__()
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+
+        self.drop = nn.Dropout(p_dropout)
+        self.convs_sep = nn.ModuleList()
+        self.convs_1x1 = nn.ModuleList()
+        self.norms_1 = nn.ModuleList()
+        self.norms_2 = nn.ModuleList()
+        for i in range(n_layers):
+            dilation = kernel_size**i
+            padding = (kernel_size * dilation - dilation) // 2
+            self.convs_sep.append(
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    groups=channels,
+                    dilation=dilation,
+                    padding=padding,
+                )
+            )
+            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+            self.norms_1.append(LayerNorm(channels))
+            self.norms_2.append(LayerNorm(channels))
+
+    def forward(self, x, x_mask, g=None):
+        if g is not None:
+            x = x + g
+        for i in range(self.n_layers):
+            y = self.convs_sep[i](x * x_mask)
+            y = self.norms_1[i](y)
+            y = F.gelu(y)
+            y = self.convs_1x1[i](y)
+            y = self.norms_2[i](y)
+            y = F.gelu(y)
+            y = self.drop(y)
+            x = x + y
+        return x * x_mask
+
+
+class WN(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels=0,
+        p_dropout=0,
+        causal=False,
+    ):
+        super(WN, self).__init__()
+        conv1d_type = SConv1d
+        assert kernel_size % 2 == 1
+        self.hidden_channels = hidden_channels
+        self.kernel_size = (kernel_size,)
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+        self.p_dropout = p_dropout
+
+        self.in_layers = torch.nn.ModuleList()
+        self.res_skip_layers = torch.nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
+        if gin_channels != 0:
+            self.cond_layer = conv1d_type(
+                gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
+            )
+
+        for i in range(n_layers):
+            dilation = dilation_rate**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = conv1d_type(
+                hidden_channels,
+                2 * hidden_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+                norm="weight_norm",
+                causal=causal,
+            )
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * hidden_channels
+            else:
+                res_skip_channels = hidden_channels
+
+            res_skip_layer = conv1d_type(
+                hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
+            )
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, x, x_mask, g=None, **kwargs):
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+        if g is not None:
+            g = self.cond_layer(g)
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            if g is not None:
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+            else:
+                g_l = torch.zeros_like(x_in)
+
+            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = self.drop(acts)
+
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                res_acts = res_skip_acts[:, : self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels :, :]
+            else:
+                output = output + res_skip_acts
+        return output * x_mask
+
+    def remove_weight_norm(self):
+        if self.gin_channels != 0:
+            torch.nn.utils.remove_weight_norm(self.cond_layer)
+        for l in self.in_layers:
+            torch.nn.utils.remove_weight_norm(l)
+        for l in self.res_skip_layers:
+            torch.nn.utils.remove_weight_norm(l)
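Editor's note: each WN layer adds the conditioning slice g_l to its dilated-conv output, splits the result into a tanh half and a sigmoid half, and multiplies the two — the classic WaveNet gate that commons.fused_add_tanh_sigmoid_multiply is assumed to implement here. A standalone sketch:

    # Sketch of the WaveNet-style gate: tanh "content" half modulated by a sigmoid "gate" half.
    import torch

    def fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels):
        n = int(n_channels[0])
        in_act = x_in + g_l                      # inject conditioning
        t_act = torch.tanh(in_act[:, :n, :])     # content half
        s_act = torch.sigmoid(in_act[:, n:, :])  # gate half
        return t_act * s_act

    x_in = torch.randn(1, 2 * 192, 50)   # in_layers[i] emits 2 * hidden_channels
    g_l = torch.zeros_like(x_in)         # no global conditioning
    out = fused_add_tanh_sigmoid_multiply(x_in, g_l, torch.IntTensor([192]))
    print(out.shape)                     # torch.Size([1, 192, 50])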
    	
models/codec/facodec/optimizer.py
ADDED
@@ -0,0 +1,104 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import reduce

import torch
from torch.optim import AdamW


class MultiOptimizer:
    """Holds one optimizer/scheduler per sub-module and steps them through a single interface."""

    def __init__(self, optimizers=None, schedulers=None):
        # use None defaults to avoid shared mutable default arguments
        self.optimizers = optimizers if optimizers is not None else {}
        self.schedulers = schedulers if schedulers is not None else {}
        self.keys = list(self.optimizers.keys())
        self.param_groups = reduce(
            lambda x, y: x + y,
            [v.param_groups for v in self.optimizers.values()],
            [],
        )

    def state_dict(self):
        state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
        return state_dicts

    def scheduler_state_dict(self):
        state_dicts = [(key, self.schedulers[key].state_dict()) for key in self.keys]
        return state_dicts

    def load_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.optimizers[key].load_state_dict(val)
            except Exception:
                print("Failed to load optimizer state for %s" % key)

    def load_scheduler_state_dict(self, state_dict):
        for key, val in state_dict:
            try:
                self.schedulers[key].load_state_dict(val)
            except Exception:
                print("Failed to load scheduler state for %s" % key)

    def step(self, key=None, scaler=None):
        keys = [key] if key is not None else self.keys
        for k in keys:
            self._step(k, scaler)

    def _step(self, key, scaler=None):
        if scaler is not None:
            # AMP path: the GradScaler unscales gradients before stepping
            scaler.step(self.optimizers[key])
            scaler.update()
        else:
            self.optimizers[key].step()

    def zero_grad(self, key=None):
        if key is not None:
            self.optimizers[key].zero_grad()
        else:
            for k in self.keys:
                self.optimizers[k].zero_grad()

    def scheduler(self, *args, key=None):
        if key is not None:
            self.schedulers[key].step(*args)
        else:
            # torch LR schedulers expose step(), not step_batch()
            for k in self.keys:
                self.schedulers[k].step(*args)


def define_scheduler(optimizer, params):
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"])

    return scheduler


def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"):
    optim = {}
    for key, model in model_dict.items():
        model_parameters = model.parameters()
        if type == "AdamW":
            optim[key] = AdamW(
                model_parameters,
                lr=lr,
                betas=(0.9, 0.98),
                eps=1e-9,
                weight_decay=0.1,
            )
        else:
            raise ValueError("Unknown optimizer type: %s" % type)

    # note: scheduler_params_dict is currently unused; every scheduler decays with a fixed gamma
    schedulers = dict(
        [
            (key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996))
            for key, opt in optim.items()
        ]
    )

    multi_optim = MultiOptimizer(optim, schedulers)
    return multi_optim
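A minimal usage sketch for the file above (my own illustration, not part of the diff; the sub-module names are placeholders): build one AdamW per sub-module, then drive them all through the shared MultiOptimizer interface.

import torch.nn as nn

model_dict = {
    "generator": nn.Linear(16, 16),      # hypothetical sub-modules
    "discriminator": nn.Linear(16, 1),
}
opt = build_optimizer(model_dict, scheduler_params_dict={}, lr=1e-4)

opt.zero_grad()                   # zeros every optimizer
# ... compute losses and call .backward() ...
opt.step()                        # steps every optimizer
opt.step(key="generator")         # or step just one
opt.scheduler(key="generator")    # advance a single LR schedule

Note that the hard-coded gamma of 0.999996 gives a very gentle exponential decay: the learning rate halves after roughly ln(2) / 4e-6, about 173k scheduler steps.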
    	
models/codec/kmeans/repcodec_model.py
ADDED

@@ -0,0 +1,210 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from models.codec.amphion_codec.quantize import ResidualVQ
from models.codec.kmeans.vocos import VocosBackbone


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def compute_codebook_perplexity(indices, codebook_size):
    indices = indices.flatten()
    prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0)
    # exponentiated entropy: 1 under codebook collapse, codebook_size under uniform usage
    perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10)))
    return perp


class RepCodec(nn.Module):
    def __init__(
        self,
        codebook_size=8192,
        hidden_size=1024,
        codebook_dim=8,
        vocos_dim=384,
        vocos_intermediate_dim=2048,
        vocos_num_layers=12,
        num_quantizers=1,
        downsample_scale=1,
        cfg=None,
    ):
        super().__init__()
        # values from cfg, when present, take precedence over the keyword defaults
        codebook_size = (
            cfg.codebook_size
            if cfg is not None and hasattr(cfg, "codebook_size")
            else codebook_size
        )
        codebook_dim = (
            cfg.codebook_dim
            if cfg is not None and hasattr(cfg, "codebook_dim")
            else codebook_dim
        )
        hidden_size = (
            cfg.hidden_size
            if cfg is not None and hasattr(cfg, "hidden_size")
            else hidden_size
        )
        vocos_dim = (
            cfg.vocos_dim
            if cfg is not None and hasattr(cfg, "vocos_dim")
            else vocos_dim
        )
        vocos_intermediate_dim = (
            cfg.vocos_intermediate_dim
            if cfg is not None and hasattr(cfg, "vocos_intermediate_dim")
            else vocos_intermediate_dim
        )
        vocos_num_layers = (
            cfg.vocos_num_layers
            if cfg is not None and hasattr(cfg, "vocos_num_layers")
            else vocos_num_layers
        )
        num_quantizers = (
            cfg.num_quantizers
            if cfg is not None and hasattr(cfg, "num_quantizers")
            else num_quantizers
        )
        downsample_scale = (
            cfg.downsample_scale
            if cfg is not None and hasattr(cfg, "downsample_scale")
            else downsample_scale
        )

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.hidden_size = hidden_size
        self.vocos_dim = vocos_dim
        self.vocos_intermediate_dim = vocos_intermediate_dim
        self.vocos_num_layers = vocos_num_layers
        self.num_quantizers = num_quantizers
        self.downsample_scale = downsample_scale

        # note: this path hardcodes a factor of 2 (stride-2 down, 2x nearest upsampling)
        if self.downsample_scale is not None and self.downsample_scale > 1:
            self.down = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            )
            self.up = nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1
            )

        self.encoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )
        self.decoder = nn.Sequential(
            VocosBackbone(
                input_channels=self.hidden_size,
                dim=self.vocos_dim,
                intermediate_dim=self.vocos_intermediate_dim,
                num_layers=self.vocos_num_layers,
                adanorm_num_embeddings=None,
            ),
            nn.Linear(self.vocos_dim, self.hidden_size),
        )

        self.quantizer = ResidualVQ(
            input_dim=hidden_size,
            num_quantizers=num_quantizers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_type="fvq",
            quantizer_dropout=0.0,
            commitment=0.15,
            codebook_loss_weight=1.0,
            use_l2_normlize=True,
        )

        self.reset_parameters()

    def forward(self, x):

        # downsample along the time axis (stride-2 conv, so T -> ceil(T / 2))
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        # encoder
        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        # vq
        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        # decoder
        x = self.decoder(quantized_out)

        # upsample back to the input frame rate
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            x_rec = self.up(x).transpose(1, 2)
        else:
            x_rec = x

        codebook_loss = (all_codebook_losses + all_commit_losses).mean()

        return x_rec, codebook_loss, all_indices

    def quantize(self, x):
        # same encode path as forward(), returning codes and quantized embeddings only
        if self.downsample_scale is not None and self.downsample_scale > 1:
            x = x.transpose(1, 2)
            x = self.down(x)
            x = F.gelu(x)
            x = x.transpose(1, 2)

        x = self.encoder(x.transpose(1, 2)).transpose(1, 2)

        (
            quantized_out,
            all_indices,
            all_commit_losses,
            all_codebook_losses,
            _,
        ) = self.quantizer(x)

        if all_indices.shape[0] == 1:
            # single quantizer: drop the leading quantizer axis
            return all_indices.squeeze(0), quantized_out.transpose(1, 2)
        return all_indices, quantized_out.transpose(1, 2)

    def reset_parameters(self):
        self.apply(init_weights)


if __name__ == "__main__":
    repcodec = RepCodec(vocos_dim=1024, downsample_scale=2)
    print(repcodec)
    print(sum(p.numel() for p in repcodec.parameters()) / 1e6)
    x = torch.randn(5, 10, 1024)
    x_rec, codebook_loss, all_indices = repcodec(x)
    print(x_rec.shape, codebook_loss, all_indices.shape)
    vq_id, emb = repcodec.quantize(x)
    print(vq_id.shape, emb.shape)
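For intuition, the compute_codebook_perplexity helper above is the exponentiated entropy of the empirical code distribution: it approaches codebook_size when every code is used equally and collapses toward 1 when a single code dominates. A quick sanity check (my own illustration, not part of the diff):

import torch

codebook_size = 8
uniform = torch.arange(codebook_size).repeat(100)      # every code used equally
collapsed = torch.zeros(800, dtype=torch.long)         # a single code used exclusively
print(compute_codebook_perplexity(uniform, codebook_size))    # ~8.0
print(compute_codebook_perplexity(collapsed, codebook_size))  # ~1.0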
    	
models/codec/kmeans/vocos.py
ADDED

@@ -0,0 +1,850 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
from torch.nn.utils import weight_norm, remove_weight_norm
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)


class STFT(nn.Module):
    def __init__(
        self,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center=True,
    ):
        super().__init__()
        self.center = center
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (B, T * hop_length)

        if not self.center:
            pad = self.win_length - self.hop_length
            x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")

        stft_spec = torch.stft(
            x,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            return_complex=False,
        )  # (B, n_fft // 2 + 1, T, 2)

        rea = stft_spec[:, :, :, 0]  # (B, n_fft // 2 + 1, T)
        imag = stft_spec[:, :, :, 1]  # (B, n_fft // 2 + 1, T)

        log_mag = torch.log(
            torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2)) + 1e-5
        )  # (B, n_fft // 2 + 1, T)
        phase = torch.atan2(imag, rea)  # (B, n_fft // 2 + 1, T)

        return log_mag, phase

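# Usage sketch (my own illustration, not part of the diff): with center=False,
# the reflect pre-padding of (win_length - hop_length) samples makes torch.stft
# emit exactly T = L / hop_length frames, i.e. "same"-style padding:
#
#     stft_same = STFT(n_fft=1024, hop_length=256, win_length=1024, center=False)
#     log_mag, phase = stft_same(torch.randn(1, 256 * 32))
#     assert log_mag.shape[-1] == 32
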
| 83 | 
         
            +
            class ISTFT(nn.Module):
         
     | 
| 84 | 
         
            +
                """
         
     | 
| 85 | 
         
            +
                Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
         
     | 
| 86 | 
         
            +
                windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
         
     | 
| 87 | 
         
            +
                See issue: https://github.com/pytorch/pytorch/issues/62323
         
     | 
| 88 | 
         
            +
                Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
         
     | 
| 89 | 
         
            +
                The NOLA constraint is met as we trim padded samples anyway.
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                Args:
         
     | 
| 92 | 
         
            +
                    n_fft (int): Size of Fourier transform.
         
     | 
| 93 | 
         
            +
                    hop_length (int): The distance between neighboring sliding window frames.
         
     | 
| 94 | 
         
            +
                    win_length (int): The size of window frame and STFT filter.
         
     | 
| 95 | 
         
            +
                    padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
         
     | 
| 96 | 
         
            +
                """
         
     | 
| 97 | 
         
            +
             
     | 
| 98 | 
         
            +
                def __init__(
         
     | 
| 99 | 
         
            +
                    self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
         
     | 
| 100 | 
         
            +
                ):
         
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y
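
# Usage sketch (illustrative; not part of the upstream code): with
# padding="center" the module simply mirrors torch.stft/torch.istft, so a
# round trip reconstructs the waveform up to numerical error.
#
#   istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024, padding="center")
#   audio = torch.randn(2, 16384)
#   spec = torch.stft(audio, 1024, 256, 1024, torch.hann_window(1024), return_complex=True)
#   recon = istft(spec)  # -> (2, 16384)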
         

class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                and N is the number of frequency bins.
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(
                audio, (self.frame_len // 2, self.frame_len // 2)
            )
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(
                audio, (self.frame_len // 4, self.frame_len // 4)
            )
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        X = torch.fft.fft(
            x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
        )[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)
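
# Shape sketch (illustrative; assumes T divisible by the hop): with
# padding="same" the hop is frame_len // 2, so a waveform of length T yields
# L = T // (frame_len // 2) frames of N = frame_len // 2 coefficients each.
#
#   mdct = MDCT(frame_len=512, padding="same")
#   coeffs = mdct(torch.randn(2, 8192))  # -> (2, 32, 256)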
         

class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
        """
        B, L, N = X.shape
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(
            Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
        )
        y = (
            torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
            * np.sqrt(N)
            * np.sqrt(2)
        )
        result = y * self.window.expand(y.shape)
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        audio = audio[:, pad:-pad]
        return audio
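
# Round-trip sketch (illustrative): MDCT followed by IMDCT with the same
# frame_len and padding reconstructs the waveform up to numerical error away
# from the signal edges, since the cosine window satisfies the Princen-Bradley
# condition for overlap-add.
#
#   imdct = IMDCT(frame_len=512, padding="same")
#   recon = imdct(coeffs)  # -> (2, 8192); coeffs from the MDCT sketch above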
         

class FourierHead(nn.Module):
    """Base class for inverse Fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of the Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        # Phase wrapping happens here: these two lines produce the real and imaginary parts.
        x = torch.cos(p)
        y = torch.sin(p)
        # Recalculating the phase (phase = torch.atan2(y, x); S = mag * torch.exp(phase * 1j))
        # would not produce anything new and only costs time, so build the complex value directly.
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio
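
# Usage sketch (illustrative): out_dim = n_fft + 2 because the head predicts
# n_fft // 2 + 1 log-magnitudes and n_fft // 2 + 1 phases per frame. With
# padding="same", L input frames map to L * hop_length output samples.
#
#   head = ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same")
#   audio = head(torch.randn(2, 64, 512))  # (B, L, H) -> (2, 16384)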
         

class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients using a symmetric exponential function.

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
            based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: Optional[int] = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # Optionally initialize the last layer following the mel scale.
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(
            x, min=-1e2, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the coefficients

        return audio
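
# Note (assumption): symexp, defined earlier in this file, is assumed to follow
# the usual Vocos definition symexp(x) = sign(x) * (exp(|x|) - 1), the inverse
# of symlog; it lets a plain linear layer predict signed MDCT coefficients
# spanning a large dynamic range.
#
#   head = IMDCTSymExpHead(dim=512, mdct_frame_len=512, sample_rate=24000)
#   audio = head(torch.randn(2, 32, 512))  # (B, L, H) -> (2, 8192)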
         

class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p).

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        clip_audio: bool = False,
    ):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(
            max=1e2
        )  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the coefficients
        return audio
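
# Design note (illustrative): unlike IMDCTSymExpHead, this head predicts
# mdct_frame_len values per frame and splits them into a log-magnitude m and a
# phase-like term p, mirroring the magnitude/phase split used by ISTFTHead.
#
#   head = IMDCTCosHead(dim=512, mdct_frame_len=512)
#   audio = head(torch.randn(2, 32, 512))  # -> (2, 8192)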
         

class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to a 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float): Initial value for the layer scale. A value of 0 or less
            disables scaling.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
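
# Shape sketch (illustrative): the block is residual and fully shape-preserving,
# so it can be stacked without changing the temporal resolution.
#
#   block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1 / 8)
#   y = block(torch.randn(2, 512, 100))  # (B, C, T) -> (B, C, T)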
         

class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes.

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        self.shift = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x
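
# Usage sketch (illustrative): the embeddings start at scale=1 and shift=0, so
# an untrained AdaLayerNorm behaves exactly like a plain LayerNorm.
#
#   norm = AdaLayerNorm(num_embeddings=4, embedding_dim=512)
#   y = norm(torch.randn(2, 100, 512), torch.tensor([0]))  # one class id, broadcast over the batch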
         

class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        # Dilated convolutions, one per entry in `dilation`.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=self.get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )

        # Matching non-dilated convolutions.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self.get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )

        self.gamma = nn.ParameterList(
            [
                (
                    nn.Parameter(
                        layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
                    )
                    if layer_scale_init_value is not None
                    else None
                )
                for _ in dilation
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)
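
# Padding arithmetic (illustrative): a stride-1 convolution with kernel k and
# dilation d spans d * (k - 1) + 1 samples, so padding (k * d - d) // 2 ==
# d * (k - 1) // 2 on each side preserves the sequence length.
# E.g. k=3, d=5 -> padding 5.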
         

class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                C denotes the number of input feature channels, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization.

    Args:
        input_channels (int): Number of input feature channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means a non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
     | 
| 751 | 
         
            +
                    self.convnext = nn.ModuleList(
         
     | 
| 752 | 
         
            +
                        [
         
     | 
| 753 | 
         
            +
                            ConvNeXtBlock(
         
     | 
| 754 | 
         
            +
                                dim=dim,
         
     | 
| 755 | 
         
            +
                                intermediate_dim=intermediate_dim,
         
     | 
| 756 | 
         
            +
                                layer_scale_init_value=layer_scale_init_value,
         
     | 
| 757 | 
         
            +
                                adanorm_num_embeddings=adanorm_num_embeddings,
         
     | 
| 758 | 
         
            +
                            )
         
     | 
| 759 | 
         
            +
                            for _ in range(num_layers)
         
     | 
| 760 | 
         
            +
                        ]
         
     | 
| 761 | 
         
            +
                    )
         
     | 
| 762 | 
         
            +
                    self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
         
     | 
| 763 | 
         
            +
                    self.apply(self._init_weights)
         
     | 
| 764 | 
         
            +
             
     | 
| 765 | 
         
            +
                def _init_weights(self, m):
         
     | 
| 766 | 
         
            +
                    if isinstance(m, (nn.Conv1d, nn.Linear)):
         
     | 
| 767 | 
         
            +
                        nn.init.trunc_normal_(m.weight, std=0.02)
         
     | 
| 768 | 
         
            +
                        nn.init.constant_(m.bias, 0)
         
     | 
| 769 | 
         
            +
             
     | 
| 770 | 
         
            +
                def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         
     | 
| 771 | 
         
            +
                    bandwidth_id = kwargs.get("bandwidth_id", None)
         
     | 
| 772 | 
         
            +
                    x = self.embed(x)
         
     | 
| 773 | 
         
            +
                    if self.adanorm:
         
     | 
| 774 | 
         
            +
                        assert bandwidth_id is not None
         
     | 
| 775 | 
         
            +
                        x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
         
     | 
| 776 | 
         
            +
                    else:
         
     | 
| 777 | 
         
            +
                        x = self.norm(x.transpose(1, 2))
         
     | 
| 778 | 
         
            +
                    x = x.transpose(1, 2)
         
     | 
| 779 | 
         
            +
                    for conv_block in self.convnext:
         
     | 
| 780 | 
         
            +
                        x = conv_block(x, cond_embedding_id=bandwidth_id)
         
     | 
| 781 | 
         
            +
                    x = self.final_layer_norm(x.transpose(1, 2))
         
     | 
| 782 | 
         
            +
                    return x
         
     | 
| 783 | 
         
            +
             
     | 
| 784 | 
         
            +
             
     | 
| 785 | 
         
            +
            class VocosResNetBackbone(Backbone):
         
     | 
| 786 | 
         
            +
                """
         
     | 
| 787 | 
         
            +
                Vocos backbone module built with ResBlocks.
         
     | 
| 788 | 
         
            +
             
     | 
| 789 | 
         
            +
                Args:
         
     | 
| 790 | 
         
            +
                    input_channels (int): Number of input features channels.
         
     | 
| 791 | 
         
            +
                    dim (int): Hidden dimension of the model.
         
     | 
| 792 | 
         
            +
                    num_blocks (int): Number of ResBlock1 blocks.
         
     | 
| 793 | 
         
            +
                    layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
         
     | 
| 794 | 
         
            +
                """
         
     | 
| 795 | 
         
            +
             
     | 
| 796 | 
         
            +
                def __init__(
         
     | 
| 797 | 
         
            +
                    self,
         
     | 
| 798 | 
         
            +
                    input_channels,
         
     | 
| 799 | 
         
            +
                    dim,
         
     | 
| 800 | 
         
            +
                    num_blocks,
         
     | 
| 801 | 
         
            +
                    layer_scale_init_value=None,
         
     | 
| 802 | 
         
            +
                ):
         
     | 
| 803 | 
         
            +
                    super().__init__()
         
     | 
| 804 | 
         
            +
                    self.input_channels = input_channels
         
     | 
| 805 | 
         
            +
                    self.embed = weight_norm(
         
     | 
| 806 | 
         
            +
                        nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
         
     | 
| 807 | 
         
            +
                    )
         
     | 
| 808 | 
         
            +
                    layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
         
     | 
| 809 | 
         
            +
                    self.resnet = nn.Sequential(
         
     | 
| 810 | 
         
            +
                        *[
         
     | 
| 811 | 
         
            +
                            ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
         
     | 
| 812 | 
         
            +
                            for _ in range(num_blocks)
         
     | 
| 813 | 
         
            +
                        ]
         
     | 
| 814 | 
         
            +
                    )
         
     | 
| 815 | 
         
            +
             
     | 
| 816 | 
         
            +
                def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         
     | 
| 817 | 
         
            +
                    x = self.embed(x)
         
     | 
| 818 | 
         
            +
                    x = self.resnet(x)
         
     | 
| 819 | 
         
            +
                    x = x.transpose(1, 2)
         
     | 
| 820 | 
         
            +
                    return x
         
     | 
| 821 | 
         
            +
             
     | 
| 822 | 
         
            +
             
     | 
| 823 | 
         
            +
            class Vocos(nn.Module):
         
     | 
| 824 | 
         
            +
                def __init__(
         
     | 
| 825 | 
         
            +
                    self,
         
     | 
| 826 | 
         
            +
                    input_channels: int = 256,
         
     | 
| 827 | 
         
            +
                    dim: int = 384,
         
     | 
| 828 | 
         
            +
                    intermediate_dim: int = 1152,
         
     | 
| 829 | 
         
            +
                    num_layers: int = 8,
         
     | 
| 830 | 
         
            +
                    adanorm_num_embeddings: int = 4,
         
     | 
| 831 | 
         
            +
                    n_fft: int = 800,
         
     | 
| 832 | 
         
            +
                    hop_size: int = 200,
         
     | 
| 833 | 
         
            +
                    padding: str = "same",
         
     | 
| 834 | 
         
            +
                ):
         
     | 
| 835 | 
         
            +
                    super().__init__()
         
     | 
| 836 | 
         
            +
             
     | 
| 837 | 
         
            +
                    self.backbone = VocosBackbone(
         
     | 
| 838 | 
         
            +
                        input_channels=input_channels,
         
     | 
| 839 | 
         
            +
                        dim=dim,
         
     | 
| 840 | 
         
            +
                        intermediate_dim=intermediate_dim,
         
     | 
| 841 | 
         
            +
                        num_layers=num_layers,
         
     | 
| 842 | 
         
            +
                        adanorm_num_embeddings=adanorm_num_embeddings,
         
     | 
| 843 | 
         
            +
                    )
         
     | 
| 844 | 
         
            +
                    self.head = ISTFTHead(dim, n_fft, hop_size, padding)
         
     | 
| 845 | 
         
            +
             
     | 
| 846 | 
         
            +
                def forward(self, x):
         
     | 
| 847 | 
         
            +
                    x = self.backbone(x)
         
     | 
| 848 | 
         
            +
                    x = self.head(x)
         
     | 
| 849 | 
         
            +
             
     | 
| 850 | 
         
            +
                    return x[:, None, :]
         
models/codec/ns3_codec/README.md
ADDED
@@ -0,0 +1,216 @@
## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3

[Paper](https://arxiv.org/pdf/2403.03100.pdf)
[Project page](https://speechresearch.github.io/naturalspeech3/)
[Pretrained model (HuggingFace)](https://huggingface.co/amphion/naturalspeech3_facodec)
[Online demo (HuggingFace Space)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)

## Overview

FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. FACodec converts complex speech waveforms into disentangled subspaces representing the speech attributes content, prosody, timbre, and acoustic details, and reconstructs high-quality speech waveforms from these attributes. By decomposing speech into subspaces for the different attributes, FACodec simplifies the modeling of speech representations.

Researchers can use FACodec to develop different kinds of TTS models, such as non-autoregressive discrete-diffusion models (NaturalSpeech 3) or autoregressive models (like VALL-E).

<br>
<div align="center">
<img src="../../../imgs/ns3/ns3_overview.png" width="65%">
</div>
<br>

<br>
<div align="center">
<img src="../../../imgs/ns3/ns3_facodec.png" width="100%">
</div>
<br>

## Usage
         
Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec)

Install Amphion
```bash
git clone https://github.com/open-mmlab/Amphion.git
```

A few lines of code suffice to use the pre-trained FACodec model
```python
import torch

from Amphion.models.codec.ns3_codec import FACodecEncoder, FACodecDecoder
from huggingface_hub import hf_hub_download

fa_encoder = FACodecEncoder(
    ngf=32,
    up_ratios=[2, 4, 5, 5],
    out_channels=256,
)

fa_decoder = FACodecDecoder(
    in_channels=256,
    upsample_initial_channel=1024,
    ngf=32,
    up_ratios=[5, 5, 4, 2],
    vq_num_q_c=2,
    vq_num_q_p=1,
    vq_num_q_r=3,
    vq_dim=256,
    codebook_dim=8,
    codebook_size_prosody=10,
    codebook_size_content=10,
    codebook_size_residual=10,
    use_gr_x_timbre=True,
    use_gr_residual_f0=True,
    use_gr_residual_phone=True,
)

encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")

fa_encoder.load_state_dict(torch.load(encoder_ckpt))
fa_decoder.load_state_dict(torch.load(decoder_ckpt))

fa_encoder.eval()
fa_decoder.eval()
```
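
One detail worth noticing in the configuration above: the encoder's `up_ratios=[2, 4, 5, 5]` multiply to 2 × 4 × 5 × 5 = 200, and the decoder mirrors them as `[5, 5, 4, 2]`. This factor of 200 matches the codec's 200-sample hop size (see the Q&A below), i.e. one latent frame per 200 waveform samples.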
         
Inference
```python
import librosa
import soundfile as sf

test_wav_path = "test.wav"
test_wav = librosa.load(test_wav_path, sr=16000)[0]
test_wav = torch.from_numpy(test_wav).float()
# add batch and channel dimensions: (T,) -> (1, 1, T)
test_wav = test_wav.unsqueeze(0).unsqueeze(0)

with torch.no_grad():

    # encode
    enc_out = fa_encoder(test_wav)
    print(enc_out.shape)

    # quantize
    vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True)

    # latent after quantization
    print(vq_post_emb.shape)

    # codes
    print("vq id shape:", vq_id.shape)

    # get prosody code
    prosody_code = vq_id[:1]
    print("prosody code shape:", prosody_code.shape)

    # get content code
    content_code = vq_id[1:3]
    print("content code shape:", content_code.shape)

    # get residual code (acoustic detail codes)
    residual_code = vq_id[3:]
    print("residual code shape:", residual_code.shape)

    # speaker embedding
    print("speaker embedding shape:", spk_embs.shape)

    # decode (recommended)
    recon_wav = fa_decoder.inference(vq_post_emb, spk_embs)
    print(recon_wav.shape)
    sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000)
```
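
The leading dimension of `vq_id` indexes the six codebooks in a fixed order matching the decoder configuration: one prosody codebook (`vq_num_q_p=1`), then two content codebooks (`vq_num_q_c=2`), then three acoustic-detail codebooks (`vq_num_q_r=3`), which is why the slices above are `[:1]`, `[1:3]`, and `[3:]`.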
         
FACodec can achieve zero-shot voice conversion with FACodecEncoderV2/FACodecDecoderV2 or FACodecRedecoder
```python
from Amphion.models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2

# Same parameters as FACodecEncoder/FACodecDecoder
fa_encoder_v2 = FACodecEncoderV2(...)
fa_decoder_v2 = FACodecDecoderV2(...)

encoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder_v2.bin")
decoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder_v2.bin")

fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt))
fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt))

# wav_a: source utterance, wav_b: reference utterance of the target speaker,
# both preprocessed like test_wav above
with torch.no_grad():
    enc_out_a = fa_encoder_v2(wav_a)
    prosody_a = fa_encoder_v2.get_prosody_feature(wav_a)
    enc_out_b = fa_encoder_v2(wav_b)
    prosody_b = fa_encoder_v2.get_prosody_feature(wav_b)

    vq_post_emb_a, vq_id_a, _, quantized, spk_embs_a = fa_decoder_v2(
        enc_out_a, prosody_a, eval_vq=False, vq=True
    )
    vq_post_emb_b, vq_id_b, _, quantized, spk_embs_b = fa_decoder_v2(
        enc_out_b, prosody_b, eval_vq=False, vq=True
    )

    vq_post_emb_a_to_b = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False)
    recon_wav_a_to_b = fa_decoder_v2.inference(vq_post_emb_a_to_b, spk_embs_b)
```
         
or

```python
from Amphion.models.codec.ns3_codec import FACodecRedecoder

fa_redecoder = FACodecRedecoder()

redecoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_redecoder.bin")

fa_redecoder.load_state_dict(torch.load(redecoder_ckpt))

with torch.no_grad():
    enc_out_a = fa_encoder(wav_a)
    enc_out_b = fa_encoder(wav_b)

    vq_post_emb_a, vq_id_a, _, quantized_a, spk_embs_a = fa_decoder(enc_out_a, eval_vq=False, vq=True)
    vq_post_emb_b, vq_id_b, _, quantized_b, spk_embs_b = fa_decoder(enc_out_b, eval_vq=False, vq=True)

    # convert speaker
    vq_post_emb_a_to_b = fa_redecoder.vq2emb(vq_id_a, spk_embs_b, use_residual=False)
    recon_wav_a_to_b = fa_redecoder.inference(vq_post_emb_a_to_b, spk_embs_b)

    sf.write("recon_a_to_b.wav", recon_wav_a_to_b[0][0].cpu().numpy(), 16000)
```
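
In both variants, `use_residual=False` rebuilds the quantized embedding without the source's acoustic-detail (residual) codes, so the timbre of the converted output is carried by the target speaker embedding `spk_embs_b`.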
         
## Q&A

Q1: What audio sample rate does FACodec support? What is the hop size? How many codes will be generated for each frame?

A1: FACodec supports 16 kHz speech audio. The hop size is 200 samples, so the audio is encoded at 16000 / 200 = 80 frames per second, and each frame is quantized by 6 codebooks in total, i.e. 6 codes per frame (480 codes per second).
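
As a quick sanity check, the bookkeeping in A1 works out as follows (a small sketch; the constants come from the answer above):

```python
sr, hop_size, num_codebooks = 16000, 200, 6
frames_per_second = sr // hop_size                    # 80 latent frames per second
codes_per_frame = num_codebooks                       # 6 codes per frame
codes_per_second = frames_per_second * num_codebooks  # 480 codes per second
```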
         
Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec?

A2: Yes. In fact, the authors of NaturalSpeech 3 have already explored autoregressive generative models for discrete token generation with FACodec. They use an autoregressive language model to generate the prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic-detail codes.

Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech2 using FACodec?

A3: Yes. You can use the latent obtained after quantization as the modeling target for the latent diffusion model.

Q4: Can FACodec compress and reconstruct audio from other domains, such as sound effects or music?

A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. It is still possible to compress and reconstruct such audio with FACodec, but the quality may not be as good as for speech.

Q5: Can FACodec provide content features for other tasks like voice conversion?

A5: We believe so. Researchers can use the content codes of FACodec as the content feature for voice conversion. We hope to see more research in this direction.
## Citations

If you use our FACodec model, please cite the following papers:

```bibtex
@article{ju2024naturalspeech,
  title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models},
  author={Ju, Zeqian and Wang, Yuancheng and Shen, Kai and Tan, Xu and Xin, Detai and Yang, Dongchao and Liu, Yanqing and Leng, Yichong and Song, Kaitao and Tang, Siliang and others},
  journal={arXiv preprint arXiv:2403.03100},
  year={2024}
}

@article{zhang2023amphion,
  title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
  author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu},
  journal={arXiv preprint arXiv:2312.09911},
  year={2024}
}
```
        models/codec/ns3_codec/__init__.py
    ADDED
    
@@ -0,0 +1,6 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .facodec import *
    	
        models/codec/ns3_codec/alias_free_torch/__init__.py
    ADDED
    
@@ -0,0 +1,5 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

from .filter import *
from .resample import *
from .act import *
    	
        models/codec/ns3_codec/alias_free_torch/act.py
    ADDED
    
@@ -0,0 +1,29 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
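
# Usage sketch (editor's illustrative note, not part of the original file):
# wrap any pointwise nonlinearity so it is applied at 2x the sample rate and
# low-pass filtered back down, suppressing the aliasing it would otherwise
# introduce; `nn.SiLU` below is just a stand-in activation.
#
#     import torch
#     act = Activation1d(nn.SiLU())
#     y = act(torch.randn(1, 32, 16000))   # output keeps the [B, C, T] shape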
         
models/codec/ns3_codec/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,96 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
    # view as [1, 1, kernel_size]; assigned outside the else so the
    # cutoff == 0 branch also returns a (zero) filter instead of raising
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
         
models/codec/ns3_codec/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,57 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
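
# Usage sketch (editor's illustrative note, not part of the original file):
# a 2x anti-aliased up/down round trip that preserves the input length.
#
#     import torch
#     up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
#     x = torch.randn(1, 1, 16000)
#     y = down(up(x))   # y.shape == x.shape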
         