Commit 4300fed · Init

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +36 -0
- .gitignore +11 -0
- LICENSE +19 -0
- README.md +162 -0
- app.py +23 -0
- logo.png +0 -0
- melo/__init__.py +0 -0
- melo/api.py +113 -0
- melo/attentions.py +459 -0
- melo/commons.py +160 -0
- melo/download_utils.py +47 -0
- melo/mel_processing.py +174 -0
- melo/models.py +1038 -0
- melo/modules.py +598 -0
- melo/split_utils.py +131 -0
- melo/text/__init__.py +35 -0
- melo/text/chinese.py +199 -0
- melo/text/chinese_bert.py +107 -0
- melo/text/chinese_mix.py +253 -0
- melo/text/cleaner.py +36 -0
- melo/text/cleaner_multiling.py +110 -0
- melo/text/cmudict.rep +0 -0
- melo/text/cmudict_cache.pickle +3 -0
- melo/text/english.py +284 -0
- melo/text/english_bert.py +39 -0
- melo/text/english_utils/__init__.py +0 -0
- melo/text/english_utils/abbreviations.py +35 -0
- melo/text/english_utils/number_norm.py +97 -0
- melo/text/english_utils/time_norm.py +47 -0
- melo/text/es_phonemizer/__init__.py +0 -0
- melo/text/es_phonemizer/base.py +140 -0
- melo/text/es_phonemizer/cleaner.py +109 -0
- melo/text/es_phonemizer/es_symbols.json +79 -0
- melo/text/es_phonemizer/es_symbols.txt +1 -0
- melo/text/es_phonemizer/es_symbols_v2.json +83 -0
- melo/text/es_phonemizer/es_to_ipa.py +12 -0
- melo/text/es_phonemizer/gruut_wrapper.py +253 -0
- melo/text/es_phonemizer/punctuation.py +174 -0
- melo/text/es_phonemizer/spanish_symbols.txt +1 -0
- melo/text/es_phonemizer/test.ipynb +124 -0
- melo/text/fr_phonemizer/__init__.py +0 -0
- melo/text/fr_phonemizer/base.py +140 -0
- melo/text/fr_phonemizer/cleaner.py +122 -0
- melo/text/fr_phonemizer/en_symbols.json +78 -0
- melo/text/fr_phonemizer/fr_symbols.json +89 -0
- melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
- melo/text/fr_phonemizer/french_abbreviations.py +48 -0
- melo/text/fr_phonemizer/french_symbols.txt +1 -0
- melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
- melo/text/fr_phonemizer/punctuation.py +172 -0
    	
.gitattributes ADDED

@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+melo/text/fr_phonemizer/example_ipa.txt filter=lfs diff=lfs merge=lfs -text
    	
.gitignore ADDED

@@ -0,0 +1,11 @@
+__pycache__/
+.ipynb_checkpoints/
+basetts_outputs_use_bert/
+basetts_outputs/
+multilingual_ckpts
+basetts_outputs_package/
+build/
+*.egg-info/
+.DS_Store
+*.zip
+*.wav
    	
LICENSE ADDED

@@ -0,0 +1,19 @@
+Copyright (c) 2024 MyShell.ai
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
    	
README.md ADDED

@@ -0,0 +1,162 @@
+<div align="center">
+  <div> </div>
+  <img src="logo.png" width="200"/>
+</div>
+
+## Introduction
+MeloTTS is a **high-quality multi-lingual** text-to-speech library by [MyShell.ai](https://myshell.ai). Supported languages include:
+
+| Language | Example |
+| --- | --- |
+| English               | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-Default/speed_1.0/sent_000.wav) |
+| English (American)    | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-US/speed_1.0/sent_000.wav) |
+| English (British)     | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-BR/speed_1.0/sent_000.wav) |
+| English (Indian)      | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN_INDIA/speed_1.0/sent_000.wav) |
+| English (Australian)  | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/en/EN-AU/speed_1.0/sent_000.wav) |
+| Spanish               | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/es/ES/speed_1.0/sent_000.wav) |
+| French                | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/fr/FR/speed_1.0/sent_000.wav) |
+| Chinese (mix EN)      | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/zh/ZH/speed_1.0/sent_008.wav) |
+| Japanese              | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/jp/JP/speed_1.0/sent_000.wav) |
+| Korean                | [Link](https://myshell-public-repo-hosting.s3.amazonaws.com/myshellttsbase/examples/kr/KR/speed_1.0/sent_000.wav) |
+
+Some other features include:
+- The Chinese speaker supports `mixed Chinese and English`.
+- Fast enough for `CPU real-time inference`.
+
+## Install on Linux
+```bash
+git clone git@github.com:myshell-ai/MeloTTS.git
+cd MeloTTS
+pip install -e .
+python -m unidic download
+```
+We welcome the open-source community to make this repo `Mac` and `Windows` compatible. If you find this repo useful, please consider contributing to it.
+
+## Usage
+
+### English with Multiple Accents
+```python
+from melo.api import TTS
+
+# Speed is adjustable
+speed = 1.0
+
+# CPU is sufficient for real-time inference.
+# You can also change to cuda:0
+device = 'cpu'
+
+# English
+text = "Did you ever hear a folk tale about a giant turtle?"
+model = TTS(language='EN', device=device)
+speaker_ids = model.hps.data.spk2id
+
+# Default accent
+output_path = 'en-default.wav'
+model.tts_to_file(text, speaker_ids['EN-Default'], output_path, speed=speed)
+
+# American accent
+output_path = 'en-us.wav'
+model.tts_to_file(text, speaker_ids['EN-US'], output_path, speed=speed)
+
+# British accent
+output_path = 'en-br.wav'
+model.tts_to_file(text, speaker_ids['EN-BR'], output_path, speed=speed)
+
+# Indian accent
+output_path = 'en-india.wav'
+model.tts_to_file(text, speaker_ids['EN_INDIA'], output_path, speed=speed)
+
+# Australian accent
+output_path = 'en-au.wav'
+model.tts_to_file(text, speaker_ids['EN-AU'], output_path, speed=speed)
+
+```
+
+### Spanish
+```python
+from melo.api import TTS
+
+# Speed is adjustable
+speed = 1.0
+
+# CPU is sufficient for real-time inference.
+# You can also change to cuda:0
+device = 'cpu'
+
+text = "El resplandor del sol acaricia las olas, pintando el cielo con una paleta deslumbrante."
+model = TTS(language='ES', device=device)
+speaker_ids = model.hps.data.spk2id
+
+output_path = 'es.wav'
+model.tts_to_file(text, speaker_ids['ES'], output_path, speed=speed)
+```
+
+### French
+```python
+from melo.api import TTS
+
+# Speed is adjustable
+speed = 1.0
+device = 'cpu' # or cuda:0
+
+text = "La lueur dorée du soleil caresse les vagues, peignant le ciel d'une palette éblouissante."
+model = TTS(language='FR', device=device)
+speaker_ids = model.hps.data.spk2id
+
+output_path = 'fr.wav'
+model.tts_to_file(text, speaker_ids['FR'], output_path, speed=speed)
+```
+
+### Chinese
+```python
+from melo.api import TTS
+
+# Speed is adjustable
+speed = 1.0
+device = 'cpu' # or cuda:0
+
+text = "我最近在学习machine learning,希望能够在未来的artificial intelligence领域有所建树。"
+model = TTS(language='ZH', device=device)
+speaker_ids = model.hps.data.spk2id
+
+output_path = 'zh.wav'
+model.tts_to_file(text, speaker_ids['ZH'], output_path, speed=speed)
+```
+
+### Japanese
+```python
+from melo.api import TTS
+
+# Speed is adjustable
+speed = 1.0
+device = 'cpu' # or cuda:0
+
+text = "彼は毎朝ジョギングをして体を健康に保っています。"
+model = TTS(language='JP', device=device)
+speaker_ids = model.hps.data.spk2id
+
+output_path = 'jp.wav'
+model.tts_to_file(text, speaker_ids['JP'], output_path, speed=speed)
+```
+
+### Korean
+```python
+from melo.api import TTS
+
+# Speed is adjustable
+speed = 1.0
+device = 'cpu' # or cuda:0
+
+text = "안녕하세요! 오늘은 날씨가 정말 좋네요."
+model = TTS(language='KR', device=device)
+speaker_ids = model.hps.data.spk2id
+
+output_path = 'kr.wav'
+model.tts_to_file(text, speaker_ids['KR'], output_path, speed=speed)
+```
+
+## License
+This library is under the MIT License, free for both commercial and non-commercial use.
+
+## Acknowledgement
+This implementation is based on several excellent projects: [TTS](https://github.com/coqui-ai/TTS), [VITS](https://github.com/jaywalnut310/vits), [VITS2](https://github.com/daniilrobnikov/vits2), and [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2). We appreciate their awesome work!
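The usage snippets above hardcode `device = 'cpu'`. A minimal sketch (assuming only the API added in this commit, and the same CUDA fallback that `app.py` below uses) picks a GPU when available and keeps the synthesized audio in memory:

```python
import torch
from melo.api import TTS

# Fall back to CPU when CUDA is absent; TTS asserts availability for 'cuda' devices.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = TTS(language='EN', device=device)
speaker_ids = model.hps.data.spk2id

# With output_path=None, tts_to_file returns a float32 NumPy waveform
# (see melo/api.py later in this diff) instead of writing a .wav file.
audio = model.tts_to_file("Did you ever hear a folk tale about a giant turtle?",
                          speaker_ids['EN-Default'], output_path=None, speed=1.0)
print(audio.shape, audio.dtype)
```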
    	
app.py ADDED

@@ -0,0 +1,23 @@
+import gradio as gr
+import os, torch, io
+os.system('python -m unidic download')
+from melo.api import TTS
+speed = 1.0
+import tempfile
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+model = TTS(language='EN', device=device)
+speaker_ids = model.hps.data.spk2id
+def synthesize(speaker, text, speed=1.0):
+    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
+        model.tts_to_file(text, speaker_ids[speaker], f.name, speed=speed)
+        return f.name
+with gr.Blocks() as demo:
+    gr.Markdown('# MeloTTS\n\nAn unofficial demo of [MeloTTS](https://github.com/myshell-ai/MeloTTS) from MyShell AI. MeloTTS is a permissively licensed (MIT) SOTA multi-speaker TTS model.\n\nI am not affiliated with MyShell AI in any way.\n\nThis demo currently only supports English, but the model itself supports other languages.')
+    with gr.Group():
+        speaker = gr.Dropdown(speaker_ids.keys(), interactive=True, value='EN-Default', label='Speaker')
+        speed = gr.Slider(label='Speed', minimum=0.1, maximum=3.0, value=1.0, interactive=True)
+        text = gr.Textbox(label="Text to speak", value='The field of text to speech has seen rapid development recently')
+    btn = gr.Button('Synthesize', variant='primary')
+    aud = gr.Audio(interactive=False)
+    btn.click(synthesize, inputs=[speaker, text, speed], outputs=[aud])
+demo.queue(api_open=False).launch(show_api=False)
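The commit does not document how to start this demo, but since the script ends by calling `demo.queue(...).launch(...)`, running `python app.py` from the repo root (with `gradio` installed alongside the README's install steps) should presumably serve the Gradio UI locally.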
    	
logo.png ADDED
(binary file; no text diff shown)

melo/__init__.py ADDED
(empty file)
    	
melo/api.py ADDED

@@ -0,0 +1,113 @@
+import os
+import re
+import json
+import torch
+import librosa
+import soundfile
+import torchaudio
+import numpy as np
+import torch.nn as nn
+
+from . import utils
+from . import commons
+from .models import SynthesizerTrn
+from .split_utils import split_sentence
+from .mel_processing import spectrogram_torch, spectrogram_torch_conv
+from .download_utils import load_or_download_config, load_or_download_model
+
+class TTS(nn.Module):
+    def __init__(self,
+                language,
+                device='cuda:0'):
+        super().__init__()
+        if 'cuda' in device:
+            assert torch.cuda.is_available()
+
+        hps = load_or_download_config(language)
+
+        num_languages = hps.num_languages
+        num_tones = hps.num_tones
+        symbols = hps.symbols
+
+        model = SynthesizerTrn(
+            len(symbols),
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            num_tones=num_tones,
+            num_languages=num_languages,
+            **hps.model,
+        ).to(device)
+
+        model.eval()
+        self.model = model
+        self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
+        self.hps = hps
+        self.device = device
+
+        # load state_dict
+        checkpoint_dict = load_or_download_model(language, device)
+        self.model.load_state_dict(checkpoint_dict['model'], strict=True)
+
+        language = language.split('_')[0]
+        self.language = 'ZH_MIX_EN' if language == 'ZH' else language  # we support a ZH_MIX_EN model
+
+    @staticmethod
+    def audio_numpy_concat(segment_data_list, sr, speed=1.):
+        audio_segments = []
+        for segment_data in segment_data_list:
+            audio_segments += segment_data.reshape(-1).tolist()
+            audio_segments += [0] * int((sr * 0.05) / speed)
+        audio_segments = np.array(audio_segments).astype(np.float32)
+        return audio_segments
+
+    @staticmethod
+    def split_sentences_into_pieces(text, language):
+        texts = split_sentence(text, language_str=language)
+        print(" > Text split to sentences.")
+        print('\n'.join(texts))
+        print(" > ===========================")
+        return texts
+
+    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0):
+        language = self.language
+        texts = self.split_sentences_into_pieces(text, language)
+        audio_list = []
+        for t in texts:
+            if language in ['EN', 'ZH_MIX_EN']:
+                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
+            device = self.device
+            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
+            with torch.no_grad():
+                x_tst = phones.to(device).unsqueeze(0)
+                tones = tones.to(device).unsqueeze(0)
+                lang_ids = lang_ids.to(device).unsqueeze(0)
+                bert = bert.to(device).unsqueeze(0)
+                ja_bert = ja_bert.to(device).unsqueeze(0)
+                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+                del phones
+                speakers = torch.LongTensor([speaker_id]).to(device)
+                audio = self.model.infer(
+                        x_tst,
+                        x_tst_lengths,
+                        speakers,
+                        tones,
+                        lang_ids,
+                        bert,
+                        ja_bert,
+                        sdp_ratio=sdp_ratio,
+                        noise_scale=noise_scale,
+                        noise_scale_w=noise_scale_w,
+                        length_scale=1. / speed,
+                    )[0][0, 0].data.cpu().float().numpy()
+                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
+            audio_list.append(audio)
+        torch.cuda.empty_cache()
+        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
+
+        if output_path is None:
+            return audio
+        else:
+            soundfile.write(output_path, audio, self.hps.data.sampling_rate)
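A detail worth noting in `tts_to_file`: per-sentence segments are joined by `audio_numpy_concat`, which appends `int((sr * 0.05) / speed)` zero samples (roughly 50 ms of silence at normal speed) after each segment. A standalone sketch of that padding arithmetic; the 44100 Hz sample rate is an illustrative assumption, since the real value comes from `hps.data.sampling_rate`:

```python
import numpy as np

def audio_numpy_concat(segment_data_list, sr, speed=1.):
    # Same logic as TTS.audio_numpy_concat above: flatten each segment, then
    # append int((sr * 0.05) / speed) zero samples (~50 ms of silence) after it.
    audio_segments = []
    for segment_data in segment_data_list:
        audio_segments += segment_data.reshape(-1).tolist()
        audio_segments += [0] * int((sr * 0.05) / speed)
    return np.array(audio_segments).astype(np.float32)

segs = [np.ones(100, dtype=np.float32), np.ones(200, dtype=np.float32)]
out = audio_numpy_concat(segs, sr=44100, speed=1.0)  # sr assumed for illustration
# Each segment is followed by int(44100 * 0.05) = 2205 silent samples.
assert out.shape[0] == 100 + 200 + 2 * 2205
```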
    	
melo/attentions.py ADDED

@@ -0,0 +1,459 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from . import commons
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        window_size=4,
+        isflow=True,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.cond_layer_idx = self.n_layers
+        if "gin_channels" in kwargs:
+            self.gin_channels = kwargs["gin_channels"]
+            if self.gin_channels != 0:
+                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
+                self.cond_layer_idx = (
+                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
+                )
+                assert (
+                    self.cond_layer_idx < self.n_layers
+                ), "cond_layer_idx should be less than n_layers"
+        self.drop = nn.Dropout(p_dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+
+        for i in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    window_size=window_size,
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask, g=None):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for i in range(self.n_layers):
+            if i == self.cond_layer_idx and g is not None:
+                g = self.spk_emb_linear(g.transpose(1, 2))
+                g = g.transpose(1, 2)
+                x = x + g
+                x = x * x_mask
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=True,
+        **kwargs
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+
+        self.drop = nn.Dropout(p_dropout)
+        self.self_attn_layers = nn.ModuleList()
+        self.norm_layers_0 = nn.ModuleList()
+        self.encdec_attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for i in range(self.n_layers):
+            self.self_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    p_dropout=p_dropout,
+                    proximal_bias=proximal_bias,
+                    proximal_init=proximal_init,
+                )
+            )
+            self.norm_layers_0.append(LayerNorm(hidden_channels))
+            self.encdec_attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                    causal=True,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask, h, h_mask):
+        """
+        x: decoder input
+        h: encoder output
+        """
+        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+            device=x.device, dtype=x.dtype
+        )
+        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        for i in range(self.n_layers):
+            y = self.self_attn_layers[i](x, x, self_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_0[i](x + y)
+
+            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=None,
+        heads_share=True,
+        block_length=None,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+        if proximal_init:
+            with torch.no_grad():
+                self.conv_k.weight.copy_(self.conv_q.weight)
+                self.conv_k.bias.copy_(self.conv_q.bias)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+        if self.window_size is not None:
+            assert (
         | 
| 278 | 
            +
                            t_s == t_t
         | 
| 279 | 
            +
                        ), "Relative attention is only available for self-attention."
         | 
| 280 | 
            +
                        key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
         | 
| 281 | 
            +
                        rel_logits = self._matmul_with_relative_keys(
         | 
| 282 | 
            +
                            query / math.sqrt(self.k_channels), key_relative_embeddings
         | 
| 283 | 
            +
                        )
         | 
| 284 | 
            +
                        scores_local = self._relative_position_to_absolute_position(rel_logits)
         | 
| 285 | 
            +
                        scores = scores + scores_local
         | 
| 286 | 
            +
                    if self.proximal_bias:
         | 
| 287 | 
            +
                        assert t_s == t_t, "Proximal bias is only available for self-attention."
         | 
| 288 | 
            +
                        scores = scores + self._attention_bias_proximal(t_s).to(
         | 
| 289 | 
            +
                            device=scores.device, dtype=scores.dtype
         | 
| 290 | 
            +
                        )
         | 
| 291 | 
            +
                    if mask is not None:
         | 
| 292 | 
            +
                        scores = scores.masked_fill(mask == 0, -1e4)
         | 
| 293 | 
            +
                        if self.block_length is not None:
         | 
| 294 | 
            +
                            assert (
         | 
| 295 | 
            +
                                t_s == t_t
         | 
| 296 | 
            +
                            ), "Local attention is only available for self-attention."
         | 
| 297 | 
            +
                            block_mask = (
         | 
| 298 | 
            +
                                torch.ones_like(scores)
         | 
| 299 | 
            +
                                .triu(-self.block_length)
         | 
| 300 | 
            +
                                .tril(self.block_length)
         | 
| 301 | 
            +
                            )
         | 
| 302 | 
            +
                            scores = scores.masked_fill(block_mask == 0, -1e4)
         | 
| 303 | 
            +
                    p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         | 
| 304 | 
            +
                    p_attn = self.drop(p_attn)
         | 
| 305 | 
            +
                    output = torch.matmul(p_attn, value)
         | 
| 306 | 
            +
                    if self.window_size is not None:
         | 
| 307 | 
            +
                        relative_weights = self._absolute_position_to_relative_position(p_attn)
         | 
| 308 | 
            +
                        value_relative_embeddings = self._get_relative_embeddings(
         | 
| 309 | 
            +
                            self.emb_rel_v, t_s
         | 
| 310 | 
            +
                        )
         | 
| 311 | 
            +
                        output = output + self._matmul_with_relative_values(
         | 
| 312 | 
            +
                            relative_weights, value_relative_embeddings
         | 
| 313 | 
            +
                        )
         | 
| 314 | 
            +
                    output = (
         | 
| 315 | 
            +
                        output.transpose(2, 3).contiguous().view(b, d, t_t)
         | 
| 316 | 
            +
                    )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
         | 
| 317 | 
            +
                    return output, p_attn
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                def _matmul_with_relative_values(self, x, y):
         | 
| 320 | 
            +
                    """
         | 
| 321 | 
            +
                    x: [b, h, l, m]
         | 
| 322 | 
            +
                    y: [h or 1, m, d]
         | 
| 323 | 
            +
                    ret: [b, h, l, d]
         | 
| 324 | 
            +
                    """
         | 
| 325 | 
            +
                    ret = torch.matmul(x, y.unsqueeze(0))
         | 
| 326 | 
            +
                    return ret
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                def _matmul_with_relative_keys(self, x, y):
         | 
| 329 | 
            +
                    """
         | 
| 330 | 
            +
                    x: [b, h, l, d]
         | 
| 331 | 
            +
                    y: [h or 1, m, d]
         | 
| 332 | 
            +
                    ret: [b, h, l, m]
         | 
| 333 | 
            +
                    """
         | 
| 334 | 
            +
                    ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
         | 
| 335 | 
            +
                    return ret
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                def _get_relative_embeddings(self, relative_embeddings, length):
         | 
| 338 | 
            +
                    2 * self.window_size + 1
         | 
| 339 | 
            +
                    # Pad first before slice to avoid using cond ops.
         | 
| 340 | 
            +
                    pad_length = max(length - (self.window_size + 1), 0)
         | 
| 341 | 
            +
                    slice_start_position = max((self.window_size + 1) - length, 0)
         | 
| 342 | 
            +
                    slice_end_position = slice_start_position + 2 * length - 1
         | 
| 343 | 
            +
                    if pad_length > 0:
         | 
| 344 | 
            +
                        padded_relative_embeddings = F.pad(
         | 
| 345 | 
            +
                            relative_embeddings,
         | 
| 346 | 
            +
                            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
         | 
| 347 | 
            +
                        )
         | 
| 348 | 
            +
                    else:
         | 
| 349 | 
            +
                        padded_relative_embeddings = relative_embeddings
         | 
| 350 | 
            +
                    used_relative_embeddings = padded_relative_embeddings[
         | 
| 351 | 
            +
                        :, slice_start_position:slice_end_position
         | 
| 352 | 
            +
                    ]
         | 
| 353 | 
            +
                    return used_relative_embeddings
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                def _relative_position_to_absolute_position(self, x):
         | 
| 356 | 
            +
                    """
         | 
| 357 | 
            +
                    x: [b, h, l, 2*l-1]
         | 
| 358 | 
            +
                    ret: [b, h, l, l]
         | 
| 359 | 
            +
                    """
         | 
| 360 | 
            +
                    batch, heads, length, _ = x.size()
         | 
| 361 | 
            +
                    # Concat columns of pad to shift from relative to absolute indexing.
         | 
| 362 | 
            +
                    x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                    # Concat extra elements so to add up to shape (len+1, 2*len-1).
         | 
| 365 | 
            +
                    x_flat = x.view([batch, heads, length * 2 * length])
         | 
| 366 | 
            +
                    x_flat = F.pad(
         | 
| 367 | 
            +
                        x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
         | 
| 368 | 
            +
                    )
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                    # Reshape and slice out the padded elements.
         | 
| 371 | 
            +
                    x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
         | 
| 372 | 
            +
                        :, :, :length, length - 1 :
         | 
| 373 | 
            +
                    ]
         | 
| 374 | 
            +
                    return x_final
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                def _absolute_position_to_relative_position(self, x):
         | 
| 377 | 
            +
                    """
         | 
| 378 | 
            +
                    x: [b, h, l, l]
         | 
| 379 | 
            +
                    ret: [b, h, l, 2*l-1]
         | 
| 380 | 
            +
                    """
         | 
| 381 | 
            +
                    batch, heads, length, _ = x.size()
         | 
| 382 | 
            +
                    # pad along column
         | 
| 383 | 
            +
                    x = F.pad(
         | 
| 384 | 
            +
                        x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
         | 
| 385 | 
            +
                    )
         | 
| 386 | 
            +
                    x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         | 
| 387 | 
            +
                    # add 0's in the beginning that will skew the elements after reshape
         | 
| 388 | 
            +
                    x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
         | 
| 389 | 
            +
                    x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         | 
| 390 | 
            +
                    return x_final
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                def _attention_bias_proximal(self, length):
         | 
| 393 | 
            +
                    """Bias for self-attention to encourage attention to close positions.
         | 
| 394 | 
            +
                    Args:
         | 
| 395 | 
            +
                      length: an integer scalar.
         | 
| 396 | 
            +
                    Returns:
         | 
| 397 | 
            +
                      a Tensor with shape [1, 1, length, length]
         | 
| 398 | 
            +
                    """
         | 
| 399 | 
            +
                    r = torch.arange(length, dtype=torch.float32)
         | 
| 400 | 
            +
                    diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
         | 
| 401 | 
            +
                    return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
         | 
| 402 | 
            +
             | 
| 403 | 
            +
             | 
| 404 | 
            +
            class FFN(nn.Module):
         | 
| 405 | 
            +
                def __init__(
         | 
| 406 | 
            +
                    self,
         | 
| 407 | 
            +
                    in_channels,
         | 
| 408 | 
            +
                    out_channels,
         | 
| 409 | 
            +
                    filter_channels,
         | 
| 410 | 
            +
                    kernel_size,
         | 
| 411 | 
            +
                    p_dropout=0.0,
         | 
| 412 | 
            +
                    activation=None,
         | 
| 413 | 
            +
                    causal=False,
         | 
| 414 | 
            +
                ):
         | 
| 415 | 
            +
                    super().__init__()
         | 
| 416 | 
            +
                    self.in_channels = in_channels
         | 
| 417 | 
            +
                    self.out_channels = out_channels
         | 
| 418 | 
            +
                    self.filter_channels = filter_channels
         | 
| 419 | 
            +
                    self.kernel_size = kernel_size
         | 
| 420 | 
            +
                    self.p_dropout = p_dropout
         | 
| 421 | 
            +
                    self.activation = activation
         | 
| 422 | 
            +
                    self.causal = causal
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    if causal:
         | 
| 425 | 
            +
                        self.padding = self._causal_padding
         | 
| 426 | 
            +
                    else:
         | 
| 427 | 
            +
                        self.padding = self._same_padding
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                    self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
         | 
| 430 | 
            +
                    self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
         | 
| 431 | 
            +
                    self.drop = nn.Dropout(p_dropout)
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                def forward(self, x, x_mask):
         | 
| 434 | 
            +
                    x = self.conv_1(self.padding(x * x_mask))
         | 
| 435 | 
            +
                    if self.activation == "gelu":
         | 
| 436 | 
            +
                        x = x * torch.sigmoid(1.702 * x)
         | 
| 437 | 
            +
                    else:
         | 
| 438 | 
            +
                        x = torch.relu(x)
         | 
| 439 | 
            +
                    x = self.drop(x)
         | 
| 440 | 
            +
                    x = self.conv_2(self.padding(x * x_mask))
         | 
| 441 | 
            +
                    return x * x_mask
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                def _causal_padding(self, x):
         | 
| 444 | 
            +
                    if self.kernel_size == 1:
         | 
| 445 | 
            +
                        return x
         | 
| 446 | 
            +
                    pad_l = self.kernel_size - 1
         | 
| 447 | 
            +
                    pad_r = 0
         | 
| 448 | 
            +
                    padding = [[0, 0], [0, 0], [pad_l, pad_r]]
         | 
| 449 | 
            +
                    x = F.pad(x, commons.convert_pad_shape(padding))
         | 
| 450 | 
            +
                    return x
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                def _same_padding(self, x):
         | 
| 453 | 
            +
                    if self.kernel_size == 1:
         | 
| 454 | 
            +
                        return x
         | 
| 455 | 
            +
                    pad_l = (self.kernel_size - 1) // 2
         | 
| 456 | 
            +
                    pad_r = self.kernel_size // 2
         | 
| 457 | 
            +
                    padding = [[0, 0], [0, 0], [pad_l, pad_r]]
         | 
| 458 | 
            +
                    x = F.pad(x, commons.convert_pad_shape(padding))
         | 
| 459 | 
            +
                    return x
         | 
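`MultiHeadAttention` above is scaled dot-product attention over `[batch, channels, time]` tensors with optional windowed relative-position embeddings, and `FFN` is the pointwise convolutional feed-forward block. A minimal smoke test for the two modules, assuming the package is importable as `melo`:

# Minimal sketch (assumes `melo` is importable); shapes follow the
# [batch, channels, time] convention used throughout these modules.
import torch
from melo.attentions import MultiHeadAttention, FFN

x = torch.randn(2, 192, 50)                    # [b, d, t]
attn_mask = torch.ones(2, 1, 50, 50)           # allow all positions
attn = MultiHeadAttention(192, 192, n_heads=2, window_size=4)
y = attn(x, x, attn_mask=attn_mask)            # self-attention: query == context
print(y.shape)                                 # torch.Size([2, 192, 50])

ffn = FFN(192, 192, filter_channels=768, kernel_size=3)
x_mask = torch.ones(2, 1, 50)
print(ffn(y, x_mask).shape)                    # torch.Size([2, 192, 50])

Because the relative-position path indexes a `[n_heads_rel, 2 * window_size + 1, k_channels]` embedding table shared across query positions, it only works when query and key lengths match, which is what the `t_s == t_t` assertions enforce.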
    	
        melo/commons.py
    ADDED
    
@@ -0,0 +1,160 @@
import math
import torch
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
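The alignment helpers in this file are the ones most worth a concrete example. The sketch below is illustrative only and assumes `melo` is on the import path; `generate_path` expands per-token durations into the hard monotonic alignment used by the VITS-style models in this repo.

import torch
from melo import commons

# intersperse: typically used to weave a blank token between phonemes.
print(commons.intersperse([1, 2, 3], 0))   # [0, 1, 0, 2, 0, 3, 0]

# sequence_mask: boolean padding mask from lengths.
lengths = torch.tensor([2, 4])
print(commons.sequence_mask(lengths))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])

# generate_path: durations [b, 1, t_x] -> hard alignment [b, 1, t_y, t_x],
# where each frame attends to the single token whose duration span covers it.
duration = torch.tensor([[[2, 1, 3]]])     # token spans: frames 0-1, 2, 3-5
mask = torch.ones(1, 1, 6, 3)              # t_y = 6 frames, t_x = 3 tokens
path = commons.generate_path(duration, mask)
print(path[0, 0])                          # exactly one 1 per frame row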
    	
        melo/download_utils.py
    ADDED
    
@@ -0,0 +1,47 @@
import torch
import os
from . import utils

DOWNLOAD_CKPT_URLS = {
    'EN': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN/checkpoint.pth',
    'EN_V2': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN_V2/checkpoint.pth',
    'FR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/FR/checkpoint.pth',
    'JP': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/JP/checkpoint.pth',
    'ES': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ES/checkpoint.pth',
    'ZH': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ZH/checkpoint.pth',
    'KR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/KR/checkpoint.pth',
}

DOWNLOAD_CONFIG_URLS = {
    'EN': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN/config.json',
    'EN_V2': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/EN_V2/config.json',
    'FR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/FR/config.json',
    'JP': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/JP/config.json',
    'ES': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ES/config.json',
    'ZH': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/ZH/config.json',
    'KR': 'https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/basespeakers/KR/config.json',
}

def load_or_download_config(locale):
    language = locale.split('-')[0].upper()
    assert language in DOWNLOAD_CONFIG_URLS
    config_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/config.json')
    try:
        return utils.get_hparams_from_file(config_path)
    except:
        # download
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        os.system(f'wget {DOWNLOAD_CONFIG_URLS[language]} -O {config_path}')
    return utils.get_hparams_from_file(config_path)

def load_or_download_model(locale, device):
    language = locale.split('-')[0].upper()
    assert language in DOWNLOAD_CKPT_URLS
    ckpt_path = os.path.expanduser(f'~/.local/share/openvoice/basespeakers/{language}/checkpoint.pth')
    try:
        return torch.load(ckpt_path, map_location=device)
    except:
        # download
        os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
        os.system(f'wget {DOWNLOAD_CKPT_URLS[language]} -O {ckpt_path}')
    return torch.load(ckpt_path, map_location=device)
    	
        melo/mel_processing.py
    ADDED
    
@@ -0,0 +1,174 @@
import torch
import torch.utils.data
import librosa
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.1:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.1:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    global hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')

    # ******************** original ************************#
    # y = y.squeeze(1)
    # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
    #                   center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)

    # ******************** ConvSTFT ************************#
    freq_cutoff = n_fft // 2 + 1
    fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
    forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
    forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()

    import torch.nn.functional as F

    # if center:
    #     signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
    assert center is False

    forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
    spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)


    # ******************** Verification ************************#
    spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
    assert torch.allclose(spec1, spec2, atol=1e-4)

    spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=y.dtype, device=y.device
        )
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
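`mel_spectrogram_torch` fuses the STFT, mel projection, and log compression in one call, caching the Hann window and mel filterbank per dtype/device pair; `spectrogram_torch_conv` cross-checks a conv1d-based STFT against `torch.stft` via `torch.allclose`. A hedged usage sketch follows; the hyperparameters are illustrative placeholders rather than values read from a shipped config, and recent PyTorch versions emit a deprecation warning for `return_complex=False`:

# Illustrative call on synthetic audio (assumes `melo` is importable).
import torch
from melo.mel_processing import mel_spectrogram_torch

wav = torch.rand(1, 22050) * 2 - 1   # 1 s of synthetic audio in [-1, 1]
mel = mel_spectrogram_torch(
    wav, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=None,
)
print(mel.shape)                     # torch.Size([1, 80, 86])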
    	
        melo/models.py
    ADDED
    
    | @@ -0,0 +1,1038 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
import math
import torch
from torch import nn
from torch.nn import functional as F

from . import commons
from . import modules
from . import attentions

from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from .commons import init_weights, get_padding

            +
            class DurationDiscriminator(nn.Module):  # vits2
         | 
| 17 | 
            +
                def __init__(
         | 
| 18 | 
            +
                    self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
         | 
| 19 | 
            +
                ):
         | 
| 20 | 
            +
                    super().__init__()
         | 
| 21 | 
            +
                    self.in_channels = in_channels
         | 
| 22 | 
            +
                    self.filter_channels = filter_channels
         | 
| 23 | 
            +
                    self.kernel_size = kernel_size
         | 
| 24 | 
            +
                    self.p_dropout = p_dropout
         | 
| 25 | 
            +
                    self.gin_channels = gin_channels
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    self.drop = nn.Dropout(p_dropout)
         | 
| 28 | 
            +
                    self.conv_1 = nn.Conv1d(
         | 
| 29 | 
            +
                        in_channels, filter_channels, kernel_size, padding=kernel_size // 2
         | 
| 30 | 
            +
                    )
         | 
| 31 | 
            +
                    self.norm_1 = modules.LayerNorm(filter_channels)
         | 
| 32 | 
            +
                    self.conv_2 = nn.Conv1d(
         | 
| 33 | 
            +
                        filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
         | 
| 34 | 
            +
                    )
         | 
| 35 | 
            +
                    self.norm_2 = modules.LayerNorm(filter_channels)
         | 
| 36 | 
            +
                    self.dur_proj = nn.Conv1d(1, filter_channels, 1)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    self.pre_out_conv_1 = nn.Conv1d(
         | 
| 39 | 
            +
                        2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
         | 
| 40 | 
            +
                    )
         | 
| 41 | 
            +
                    self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
         | 
| 42 | 
            +
                    self.pre_out_conv_2 = nn.Conv1d(
         | 
| 43 | 
            +
                        filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
         | 
| 44 | 
            +
                    )
         | 
| 45 | 
            +
                    self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    if gin_channels != 0:
         | 
| 48 | 
            +
                        self.cond = nn.Conv1d(gin_channels, in_channels, 1)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def forward_probability(self, x, x_mask, dur, g=None):
         | 
| 53 | 
            +
                    dur = self.dur_proj(dur)
         | 
| 54 | 
            +
                    x = torch.cat([x, dur], dim=1)
         | 
| 55 | 
            +
                    x = self.pre_out_conv_1(x * x_mask)
         | 
| 56 | 
            +
                    x = torch.relu(x)
         | 
| 57 | 
            +
                    x = self.pre_out_norm_1(x)
         | 
| 58 | 
            +
                    x = self.drop(x)
         | 
| 59 | 
            +
                    x = self.pre_out_conv_2(x * x_mask)
         | 
| 60 | 
            +
                    x = torch.relu(x)
         | 
| 61 | 
            +
                    x = self.pre_out_norm_2(x)
         | 
| 62 | 
            +
                    x = self.drop(x)
         | 
| 63 | 
            +
                    x = x * x_mask
         | 
| 64 | 
            +
                    x = x.transpose(1, 2)
         | 
| 65 | 
            +
                    output_prob = self.output_layer(x)
         | 
| 66 | 
            +
                    return output_prob
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                def forward(self, x, x_mask, dur_r, dur_hat, g=None):
         | 
| 69 | 
            +
                    x = torch.detach(x)
         | 
| 70 | 
            +
                    if g is not None:
         | 
| 71 | 
            +
                        g = torch.detach(g)
         | 
| 72 | 
            +
                        x = x + self.cond(g)
         | 
| 73 | 
            +
                    x = self.conv_1(x * x_mask)
         | 
| 74 | 
            +
                    x = torch.relu(x)
         | 
| 75 | 
            +
                    x = self.norm_1(x)
         | 
| 76 | 
            +
                    x = self.drop(x)
         | 
| 77 | 
            +
                    x = self.conv_2(x * x_mask)
         | 
| 78 | 
            +
                    x = torch.relu(x)
         | 
| 79 | 
            +
                    x = self.norm_2(x)
         | 
| 80 | 
            +
                    x = self.drop(x)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    output_probs = []
         | 
| 83 | 
            +
                    for dur in [dur_r, dur_hat]:
         | 
| 84 | 
            +
                        output_prob = self.forward_probability(x, x_mask, dur, g)
         | 
| 85 | 
            +
                        output_probs.append(output_prob)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    return output_probs
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
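# Usage sketch (not part of the upstream file; shapes and hyper-parameters are
# illustrative assumptions). The discriminator scores real vs. predicted
# durations given the detached encoder states:
#
#   dd = DurationDiscriminator(192, 192, kernel_size=3, p_dropout=0.5)
#   x = torch.randn(2, 192, 50)       # encoder states [b, in_channels, t]
#   x_mask = torch.ones(2, 1, 50)
#   dur_r = torch.rand(2, 1, 50)      # ground-truth durations
#   dur_hat = torch.rand(2, 1, 50)    # durations from the duration predictor
#   p_real, p_fake = dd(x, x_mask, dur_r, dur_hat)   # each [b, t, 1] in (0, 1)
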
class TransformerCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        self.wn = (
            attentions.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
            if share_parameter
            else None
        )

        for i in range(n_flows):
            self.flows.append(
                modules.TransformerCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    n_layers,
                    n_heads,
                    p_dropout,
                    filter_channels,
                    mean_only=True,
                    wn_sharing_parameter=self.wn,
                    gin_channels=self.gin_channels,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x

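# Usage sketch (illustrative, with assumed VITS-style sizes): the block is an
# invertible normalizing flow, so the reverse pass undoes the forward pass:
#
#   flow = TransformerCouplingBlock(192, 192, 768, n_heads=2, n_layers=4,
#                                   kernel_size=3, p_dropout=0.1)
#   z = flow(x, x_mask)                    # latent -> flow space
#   x_rec = flow(z, x_mask, reverse=True)  # inverse recovers x (under the mask)
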
class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # TODO: remove this override in a future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove an unused flow
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw

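# Usage sketch (illustrative): in training mode the module returns the negative
# log-likelihood of the observed durations w under the flow (one value per
# batch element); in reverse mode it samples log-durations from noise:
#
#   sdp = StochasticDurationPredictor(192, 192, kernel_size=3, p_dropout=0.5)
#   nll = sdp(x, x_mask, w=w)                             # training, shape [b]
#   logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)  # inference, [b, 1, t]
#   w_hat = torch.exp(logw) * x_mask                      # durations in frames
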
class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask

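# Usage sketch (illustrative): the deterministic counterpart simply regresses
# log-durations from the (detached) encoder states:
#
#   dp = DurationPredictor(192, 256, kernel_size=3, p_dropout=0.5)
#   logw = dp(x, x_mask)              # [b, 1, t]; w = exp(logw) * x_mask
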
class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels=0,
        num_languages=None,
        num_tones=None,
    ):
        super().__init__()
        if num_languages is None:
            from .text import num_languages
        if num_tones is None:
            from .text import num_tones
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
        bert_emb = self.bert_proj(bert).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
        x = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
        ) * math.sqrt(
            self.hidden_channels
        )  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        x = self.encoder(x * x_mask, x_mask, g=g)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask

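# Usage sketch (hyper-parameters are assumptions, not the shipped config): the
# encoder fuses phone, tone, and language embeddings with projected BERT
# features (1024-d) and ja_bert features (768-d), then returns the encoded
# sequence plus the mean m and log-std logs of the prior distribution:
#
#   enc = TextEncoder(n_vocab=219, out_channels=192, hidden_channels=192,
#                     filter_channels=768, n_heads=2, n_layers=6,
#                     kernel_size=3, p_dropout=0.1,
#                     num_languages=10, num_tones=16)
#   x, m, logs, x_mask = enc(phones, lengths, tones, lang_ids, bert, ja_bert)
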
class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x

class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None, tau=1.0):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

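# Usage sketch (illustrative sizes; spec_channels is the number of
# linear-spectrogram bins): the posterior encoder samples a latent via the
# reparameterization z = m + eps * tau * exp(logs), where tau acts as a
# sampling temperature (tau=0 returns the mean):
#
#   post = PosteriorEncoder(spec_channels, 192, 192, kernel_size=5,
#                           dilation_rate=1, n_layers=16, gin_channels=256)
#   z, m, logs, y_mask = post(spec, spec_lengths, g=g, tau=0.3)
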
class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()

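# Usage sketch (HiFi-GAN-style settings, assumed rather than the shipped
# config): the overall upsampling factor is the product of upsample_rates,
# here 8 * 8 * 2 * 2 = 256 waveform samples per latent frame:
#
#   gen = Generator(192, "1", [3, 7, 11],
#                   [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
#                   [8, 8, 2, 2], 512, [16, 16, 4, 4])
#   wav = gen(z)                  # [b, 1, t_frames * 256], values in (-1, 1)
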
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

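# Shape sketch (illustrative): the waveform is reflect-padded to a multiple of
# `period` and folded into 2D so the convolutions compare period-spaced
# samples; e.g. for period=5 and t=8192, pad to 8195 and view as [b, 1, 1639, 5]:
#
#   d = DiscriminatorP(period=5)
#   score, fmap = d(torch.randn(1, 1, 8192))
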
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

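# Usage sketch (illustrative): the ensemble pairs one scale discriminator with
# five period discriminators (prime periods 2, 3, 5, 7, 11) and scores both the
# real and the generated waveform in a single call:
#
#   mpd = MultiPeriodDiscriminator()
#   y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(wav_real, wav_fake)   # 6 of each
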
class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r]  mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0, layernorm=False):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
        if layernorm:
            self.layernorm = nn.LayerNorm(self.spec_channels)
            print('[Ref Enc]: using layer norm')
        else:
            self.layernorm = None

    def forward(self, inputs, mask=None):
        N = inputs.size(0)

        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        if self.layernorm is not None:
            out = self.layernorm(out)

        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]
         | 
| 739 | 
            +
             | 
| 740 | 
            +
                    self.gru.flatten_parameters()
         | 
| 741 | 
            +
                    memory, out = self.gru(out)  # out --- [1, N, 128]
         | 
| 742 | 
            +
             | 
| 743 | 
            +
                    return self.proj(out.squeeze(0))
         | 
| 744 | 
            +
             | 
| 745 | 
            +
                def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
         | 
| 746 | 
            +
                    for i in range(n_convs):
         | 
| 747 | 
            +
                        L = (L - kernel_size + 2 * pad) // stride + 1
         | 
| 748 | 
            +
                    return L
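
# Worked example (added commentary): each stride-2 conv shrinks the frequency
# axis via L -> (L - 3 + 2) // 2 + 1. For a hypothetical spec_channels of 128
# and the K = 6 convs above:
#   128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2
# so the GRU input size is ref_enc_filters[-1] * 2 = 128 * 2 = 256.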


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=6,
        flow_share_parameter=False,
        use_transformer_flow=True,
        use_vc=False,
        num_languages=None,
        num_tones=None,
        norm_refenc=False,
        use_se=False,
        **kwargs
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        self.current_mas_noise_scale = self.mas_noise_scale_initial
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        else:
            self.enc_gin_channels = 0
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.enc_gin_channels,
            num_languages=num_languages,
            num_tones=num_tones,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        if use_transformer_flow:
            self.flow = TransformerCouplingBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = ResidualCouplingBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )
        self.sdp = StochasticDurationPredictor(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )

        if n_speakers > 1:
            if use_se:
                emb_dim = 512
                self.emb_g = nn.Linear(emb_dim, gin_channels)
            else:
                self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
        self.use_vc = use_vc
        self.use_se = use_se

    def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        if self.use_vc:
            g_p = None
        else:
            g_p = g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
            if self.use_noise_scaled_mas:
                epsilon = (
                    torch.std(neg_cent)
                    * torch.randn_like(neg_cent)
                    * self.current_mas_noise_scale
                )
                neg_cent = neg_cent + epsilon

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )
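
        # Added note: neg_cent[b, t_t, t_s] is the diagonal-Gaussian
        # log-likelihood log N(z_p[:, t_t]; m_p[:, t_s], exp(logs_p[:, t_s])^2)
        # summed over channels, expanded into four terms with
        # s_p_sq_r = exp(-2 * logs_p):
        #   -0.5*log(2*pi) - logs_p          (neg_cent1, normalization)
        #   -0.5 * z_p^2 * s_p_sq_r          (neg_cent2)
        #   +      z_p * m_p * s_p_sq_r      (neg_cent3)
        #   -0.5 * m_p^2 * s_p_sq_r          (neg_cent4)
        # maximum_path then selects the monotonic alignment maximizing this score.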

        w = attn.sum(2)

        l_length_sdp = self.sdp(x, x_mask, w, g=g)
        l_length_sdp = l_length_sdp / torch.sum(x_mask)

        logw_ = torch.log(w + 1e-6) * x_mask
        logw = self.dp(x, x_mask, g=g)
        l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
            x_mask
        )  # for averaging

        l_length = l_length_dp + l_length_sdp

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            (x, logw, logw_),
        )

    def infer(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
        sdp_ratio=0,
        y=None,
        g=None,
    ):
        # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
        # g = self.gst(y)
        if g is None:
            if self.n_speakers > 0:
                g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
            else:
                g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        if self.use_vc:
            g_p = None
        else:
            g_p = g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
            sdp_ratio
        ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale

        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        # print('max/min of o:', o.max(), o.min())
        return o, attn, y_mask, (z, z_p, m_p, logs_p)
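
    # A minimal inference sketch (added commentary; variable names such as
    # `net_g`, `phones`, `tones`, `lang_ids` are hypothetical). The phoneme,
    # tone, language, and BERT inputs would come from the text frontend:
    #
    #   with torch.no_grad():
    #       o, attn, y_mask, _ = net_g.infer(
    #           phones.unsqueeze(0), torch.LongTensor([phones.size(0)]),
    #           sid=torch.LongTensor([0]), tone=tones.unsqueeze(0),
    #           language=lang_ids.unsqueeze(0), bert=bert.unsqueeze(0),
    #           ja_bert=ja_bert.unsqueeze(0),
    #           sdp_ratio=0.2, noise_scale=0.6, length_scale=1.0,
    #       )
    #       audio = o[0, 0].cpu().numpy()  # waveform at the model sample rate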

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        if self.use_se:
            sid_src = self.emb_g(sid_src).unsqueeze(-1)
            sid_tgt = self.emb_g(sid_tgt).unsqueeze(-1)

        g_src = sid_src
        g_tgt = sid_tgt
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
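
    # Added note: voice conversion works entirely in latent space. The source
    # spectrogram is encoded under the source speaker condition, the forward
    # flow maps it into the (approximately) speaker-independent prior space
    # z_p, and the inverse flow under the target condition maps it back before
    # the decoder renders the waveform in the target voice.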

melo/modules.py
ADDED
@@ -0,0 +1,598 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm

from . import commons
from .commons import init_weights, get_padding
from .transforms import piecewise_rational_quadratic_transform
from .attentions import Encoder

LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)
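
# Added note: unlike a bare nn.LayerNorm, this wrapper handles channels-first
# tensors of shape [batch, channels, time] by moving the channel axis to the
# end and back. A small illustrative check (hypothetical shapes):
#
#   ln = LayerNorm(192)
#   h = torch.randn(2, 192, 50)   # [b, c, t]
#   assert ln(h).shape == (2, 192, 50)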


class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
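
# Added note: the output projection is zero-initialized, so at initialization
# the block computes x_org + 0 and acts as an exact identity on its input; the
# residual branch only starts contributing as training updates self.proj.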


class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask
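
# Added note: dilation grows as kernel_size**i, so the receptive field expands
# exponentially with depth. For example, with kernel_size=3 and n_layers=4 the
# per-layer dilations are 1, 3, 9, 27, while the grouped (depthwise) convs
# followed by 1x1 convs keep the parameter count low.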


class WN(torch.nn.Module):
    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
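
# Added note: WN is the WaveNet-style conditioning stack used in VITS-family
# models. Each layer projects x to 2*hidden_channels, adds the per-layer slice
# of the speaker condition g, and applies the gated activation
#   acts = tanh(first_half) * sigmoid(second_half)
# (which is what commons.fused_add_tanh_sigmoid_multiply computes in this
# codebase), splitting the result into a residual path and a skip accumulator.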


class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)
         | 
| 361 | 
            +
             | 
| 362 | 
            +
             | 
| 363 | 
            +
            class Log(nn.Module):
         | 
| 364 | 
            +
                def forward(self, x, x_mask, reverse=False, **kwargs):
         | 
| 365 | 
            +
                    if not reverse:
         | 
| 366 | 
            +
                        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
         | 
| 367 | 
            +
                        logdet = torch.sum(-y, [1, 2])
         | 
| 368 | 
            +
                        return y, logdet
         | 
| 369 | 
            +
                    else:
         | 
| 370 | 
            +
                        x = torch.exp(x) * x_mask
         | 
| 371 | 
            +
                        return x
         | 
| 372 | 
            +
             | 
| 373 | 
            +
             | 
| 374 | 
            +
            class Flip(nn.Module):
         | 
| 375 | 
            +
                def forward(self, x, *args, reverse=False, **kwargs):
         | 
| 376 | 
            +
                    x = torch.flip(x, [1])
         | 
| 377 | 
            +
                    if not reverse:
         | 
| 378 | 
            +
                        logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
         | 
| 379 | 
            +
                        return x, logdet
         | 
| 380 | 
            +
                    else:
         | 
| 381 | 
            +
                        return x
         | 
| 382 | 
            +
             | 
| 383 | 
            +
             | 
| 384 | 
            +
            class ElementwiseAffine(nn.Module):
         | 
| 385 | 
            +
                def __init__(self, channels):
         | 
| 386 | 
            +
                    super().__init__()
         | 
| 387 | 
            +
                    self.channels = channels
         | 
| 388 | 
            +
                    self.m = nn.Parameter(torch.zeros(channels, 1))
         | 
| 389 | 
            +
                    self.logs = nn.Parameter(torch.zeros(channels, 1))
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                def forward(self, x, x_mask, reverse=False, **kwargs):
         | 
| 392 | 
            +
                    if not reverse:
         | 
| 393 | 
            +
                        y = self.m + torch.exp(self.logs) * x
         | 
| 394 | 
            +
                        y = y * x_mask
         | 
| 395 | 
            +
                        logdet = torch.sum(self.logs * x_mask, [1, 2])
         | 
| 396 | 
            +
                        return y, logdet
         | 
| 397 | 
            +
                    else:
         | 
| 398 | 
            +
                        x = (x - self.m) * torch.exp(-self.logs) * x_mask
         | 
| 399 | 
            +
                        return x
         | 
| 400 | 
            +
             | 
| 401 | 
            +
             | 
| 402 | 
            +
            class ResidualCouplingLayer(nn.Module):
         | 
| 403 | 
            +
                def __init__(
         | 
| 404 | 
            +
                    self,
         | 
| 405 | 
            +
                    channels,
         | 
| 406 | 
            +
                    hidden_channels,
         | 
| 407 | 
            +
                    kernel_size,
         | 
| 408 | 
            +
                    dilation_rate,
         | 
| 409 | 
            +
                    n_layers,
         | 
| 410 | 
            +
                    p_dropout=0,
         | 
| 411 | 
            +
                    gin_channels=0,
         | 
| 412 | 
            +
                    mean_only=False,
         | 
| 413 | 
            +
                ):
         | 
| 414 | 
            +
                    assert channels % 2 == 0, "channels should be divisible by 2"
         | 
| 415 | 
            +
                    super().__init__()
         | 
| 416 | 
            +
                    self.channels = channels
         | 
| 417 | 
            +
                    self.hidden_channels = hidden_channels
         | 
| 418 | 
            +
                    self.kernel_size = kernel_size
         | 
| 419 | 
            +
                    self.dilation_rate = dilation_rate
         | 
| 420 | 
            +
                    self.n_layers = n_layers
         | 
| 421 | 
            +
                    self.half_channels = channels // 2
         | 
| 422 | 
            +
                    self.mean_only = mean_only
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                    self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
         | 
| 425 | 
            +
                    self.enc = WN(
         | 
| 426 | 
            +
                        hidden_channels,
         | 
| 427 | 
            +
                        kernel_size,
         | 
| 428 | 
            +
                        dilation_rate,
         | 
| 429 | 
            +
                        n_layers,
         | 
| 430 | 
            +
                        p_dropout=p_dropout,
         | 
| 431 | 
            +
                        gin_channels=gin_channels,
         | 
| 432 | 
            +
                    )
         | 
| 433 | 
            +
                    self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
         | 
| 434 | 
            +
                    self.post.weight.data.zero_()
         | 
| 435 | 
            +
                    self.post.bias.data.zero_()
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                def forward(self, x, x_mask, g=None, reverse=False):
         | 
| 438 | 
            +
                    x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
         | 
| 439 | 
            +
                    h = self.pre(x0) * x_mask
         | 
| 440 | 
            +
                    h = self.enc(h, x_mask, g=g)
         | 
| 441 | 
            +
                    stats = self.post(h) * x_mask
         | 
| 442 | 
            +
                    if not self.mean_only:
         | 
| 443 | 
            +
                        m, logs = torch.split(stats, [self.half_channels] * 2, 1)
         | 
| 444 | 
            +
                    else:
         | 
| 445 | 
            +
                        m = stats
         | 
| 446 | 
            +
                        logs = torch.zeros_like(m)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    if not reverse:
         | 
| 449 | 
            +
                        x1 = m + x1 * torch.exp(logs) * x_mask
         | 
| 450 | 
            +
                        x = torch.cat([x0, x1], 1)
         | 
| 451 | 
            +
                        logdet = torch.sum(logs, [1, 2])
         | 
| 452 | 
            +
                        return x, logdet
         | 
| 453 | 
            +
                    else:
         | 
| 454 | 
            +
                        x1 = (x1 - m) * torch.exp(-logs) * x_mask
         | 
| 455 | 
            +
                        x = torch.cat([x0, x1], 1)
         | 
| 456 | 
            +
                        return x
         | 
| 457 | 
            +
             | 
| 458 | 
            +
             | 
| 459 | 
            +
            class ConvFlow(nn.Module):
         | 
| 460 | 
            +
                def __init__(
         | 
| 461 | 
            +
                    self,
         | 
| 462 | 
            +
                    in_channels,
         | 
| 463 | 
            +
                    filter_channels,
         | 
| 464 | 
            +
                    kernel_size,
         | 
| 465 | 
            +
                    n_layers,
         | 
| 466 | 
            +
                    num_bins=10,
         | 
| 467 | 
            +
                    tail_bound=5.0,
         | 
| 468 | 
            +
                ):
         | 
| 469 | 
            +
                    super().__init__()
         | 
| 470 | 
            +
                    self.in_channels = in_channels
         | 
| 471 | 
            +
                    self.filter_channels = filter_channels
         | 
| 472 | 
            +
                    self.kernel_size = kernel_size
         | 
| 473 | 
            +
                    self.n_layers = n_layers
         | 
| 474 | 
            +
                    self.num_bins = num_bins
         | 
| 475 | 
            +
                    self.tail_bound = tail_bound
         | 
| 476 | 
            +
                    self.half_channels = in_channels // 2
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                    self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
         | 
| 479 | 
            +
                    self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
         | 
| 480 | 
            +
                    self.proj = nn.Conv1d(
         | 
| 481 | 
            +
                        filter_channels, self.half_channels * (num_bins * 3 - 1), 1
         | 
| 482 | 
            +
                    )
         | 
| 483 | 
            +
                    self.proj.weight.data.zero_()
         | 
| 484 | 
            +
                    self.proj.bias.data.zero_()
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                def forward(self, x, x_mask, g=None, reverse=False):
         | 
| 487 | 
            +
                    x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
         | 
| 488 | 
            +
                    h = self.pre(x0)
         | 
| 489 | 
            +
                    h = self.convs(h, x_mask, g=g)
         | 
| 490 | 
            +
                    h = self.proj(h) * x_mask
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                    b, c, t = x0.shape
         | 
| 493 | 
            +
                    h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                    unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
         | 
| 496 | 
            +
                    unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
         | 
| 497 | 
            +
                        self.filter_channels
         | 
| 498 | 
            +
                    )
         | 
| 499 | 
            +
                    unnormalized_derivatives = h[..., 2 * self.num_bins :]
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                    x1, logabsdet = piecewise_rational_quadratic_transform(
         | 
| 502 | 
            +
                        x1,
         | 
| 503 | 
            +
                        unnormalized_widths,
         | 
| 504 | 
            +
                        unnormalized_heights,
         | 
| 505 | 
            +
                        unnormalized_derivatives,
         | 
| 506 | 
            +
                        inverse=reverse,
         | 
| 507 | 
            +
                        tails="linear",
         | 
| 508 | 
            +
                        tail_bound=self.tail_bound,
         | 
| 509 | 
            +
                    )
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                    x = torch.cat([x0, x1], 1) * x_mask
         | 
| 512 | 
            +
                    logdet = torch.sum(logabsdet * x_mask, [1, 2])
         | 
| 513 | 
            +
                    if not reverse:
         | 
| 514 | 
            +
                        return x, logdet
         | 
| 515 | 
            +
                    else:
         | 
| 516 | 
            +
                        return x
         | 
| 517 | 
            +
             | 
| 518 | 
            +
             | 
class TransformerCouplingLayer(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        n_layers,
        n_heads,
        p_dropout=0,
        filter_channels=0,
        mean_only=False,
        wn_sharing_parameter=None,
        gin_channels=0,
    ):
        assert n_layers == 3, n_layers
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = (
            Encoder(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
            if wn_sharing_parameter is None
            else wn_sharing_parameter
        )
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
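
The flow blocks above (Log, Flip, ElementwiseAffine, and the coupling layers) all share one contract: with reverse=False they return the transformed tensor plus a log-determinant term for the flow likelihood, and with reverse=True they invert the mapping. A minimal sketch checking that contract on ElementwiseAffine, the only block here that needs nothing beyond torch; shapes and values are illustrative only:

import torch

layer = ElementwiseAffine(channels=4)
with torch.no_grad():
    layer.m += 0.3    # move off the identity initialization
    layer.logs += 0.1

x = torch.randn(2, 4, 7)        # [batch, channels, time]
x_mask = torch.ones(2, 1, 7)    # all frames valid in this toy case

y, logdet = layer(x, x_mask)              # forward: y = m + exp(logs) * x
x_rec = layer(y, x_mask, reverse=True)    # reverse pass recovers x
assert torch.allclose(x, x_rec, atol=1e-6)
assert logdet.shape == (2,)               # one log-determinant per batch item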
    	
        melo/split_utils.py
    ADDED
    
import re
import os
import glob
import numpy as np
import soundfile as sf
import torchaudio
from txtsplit import txtsplit

def split_sentence(text, min_len=10, language_str='EN'):
    if language_str in ['EN', 'FR', 'ES', 'SP', 'DE', 'RU']:
        sentences = split_sentences_latin(text, min_len=min_len)
    else:
        sentences = split_sentences_zh(text, min_len=min_len)
    return sentences

def split_sentences_latin(text, min_len=10):
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    text = re.sub('[“”]', '"', text)
    text = re.sub('[‘’]', "'", text)
    text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
    return txtsplit(text, 512, 512)
    # Replace newlines, spaces and tabs in the text with a single space
    # text = re.sub('[\n\t ]+', ' ', text)
    # # Add a split marker after each punctuation mark
    # text = re.sub('([,.!?;])', r'\1 $#!', text)
    # # Split into sentences and strip surrounding whitespace
    # sentences = [s.strip() for s in text.split('$#!')]
    # if len(sentences[-1]) == 0: del sentences[-1]

    # new_sentences = []
    # new_sent = []
    # count_len = 0
    # for ind, sent in enumerate(sentences):
    #     # print(sent)
    #     new_sent.append(sent)
    #     count_len += len(sent.split(" "))
    #     if count_len > min_len or ind == len(sentences) - 1:
    #         count_len = 0
    #         new_sentences.append(' '.join(new_sent))
    #         new_sent = []
    # return merge_short_sentences_en(new_sentences)

def split_sentences_zh(text, min_len=10):
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    # Replace newlines, spaces and tabs in the text with a single space
    text = re.sub('[\n\t ]+', ' ', text)
    # Add a split marker after each punctuation mark
    text = re.sub('([,.!?;])', r'\1 $#!', text)
    # Split into sentences and strip surrounding whitespace
    # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
    sentences = [s.strip() for s in text.split('$#!')]
    if len(sentences[-1]) == 0: del sentences[-1]

    new_sentences = []
    new_sent = []
    count_len = 0
    for ind, sent in enumerate(sentences):
        new_sent.append(sent)
        count_len += len(sent)
        if count_len > min_len or ind == len(sentences) - 1:
            count_len = 0
            new_sentences.append(' '.join(new_sent))
            new_sent = []
    return merge_short_sentences_zh(new_sentences)

def merge_short_sentences_en(sens):
    """Avoid short sentences by merging them with a neighboring sentence.

    Args:
        sens (List[str]): list of input sentences.

    Returns:
        List[str]: list of output sentences.
    """
    sens_out = []
    for s in sens:
        # If the previous sentence is too short (<= 2 tokens), merge it with
        # the current sentence.
        if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
            sens_out.append(s)
    try:
        # Fold a short trailing fragment into the sentence before it.
        if len(sens_out[-1].split(" ")) <= 2:
            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
            sens_out.pop(-1)
    except IndexError:
        pass
    return sens_out

def merge_short_sentences_zh(sens):
    # return sens
    """Avoid short sentences by merging them with a neighboring sentence.

    Args:
        sens (List[str]): list of input sentences.

    Returns:
        List[str]: list of output sentences.
    """
    sens_out = []
    for s in sens:
        # If the previous sentence is too short (<= 2 characters), merge it
        # with the current sentence.
        if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
            sens_out.append(s)
    try:
        # Fold a short trailing fragment into the sentence before it.
        if len(sens_out[-1]) <= 2:
            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
            sens_out.pop(-1)
    except IndexError:
        pass
    return sens_out


if __name__ == '__main__':
    zh_text = "好的,我来给你讲一个故事吧。从前有一个小姑娘,她叫做小红。小红非常喜欢在森林里玩耍,她经常会和她的小伙伴们一起去探险。有一天,小红和她的小伙伴们走到了森林深处,突然遇到了一只凶猛的野兽。小红的小伙伴们都吓得不敢动弹,但是小红并没有被吓倒,她勇敢地走向野兽,用她的智慧和勇气成功地制服了野兽,保护了她的小伙伴们。从那以后,小红变得更加勇敢和自信,成为了她小伙伴们心中的英雄。"
    en_text = "I didn’t know what to do. I said please kill her because it would be better than being kidnapped,” Ben, whose surname CNN is not using for security concerns, said on Wednesday. “It’s a nightmare. I said ‘please kill her, don’t take her there.’"
    sp_text = "¡Claro! ¿En qué tema te gustaría que te hable en español? Puedo proporcionarte información o conversar contigo sobre una amplia variedad de temas, desde cultura y comida hasta viajes y tecnología. ¿Tienes alguna preferencia en particular?"
    fr_text = "Bien sûr ! En quelle matière voudriez-vous que je vous parle en français ? Je peux vous fournir des informations ou discuter avec vous sur une grande variété de sujets, que ce soit la culture, la nourriture, les voyages ou la technologie. Avez-vous une préférence particulière ?"
    de_text = 'Es war das Wichtigste was wir sichern wollten da es keine Möglichkeit gab eine 20 Megatonnen- H- Bombe ab zu werfen von einem 30, C124.'
    ru_text = 'Но он был во многом, как-бы, всё равно что сын плантатора, так как являлся сыном человека, у которого было в собственности много чего.'
    print(split_sentence(zh_text, language_str='ZH'))
    print(split_sentence(en_text, language_str='EN'))
    print(split_sentence(sp_text, language_str='SP'))
    print(split_sentence(fr_text, language_str='FR'))
    print(split_sentence(de_text, language_str='DE'))
    print(split_sentence(ru_text, language_str='RU'))
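
The merge_short_sentences_* docstrings state the intent; concretely, a fragment of at most two tokens (two characters in the Chinese variant) is glued onto a neighboring sentence instead of becoming its own TTS chunk. A quick illustration with made-up inputs:

sens = ["Hi.", "This is a longer sentence.", "Bye."]
print(merge_short_sentences_en(sens))
# ['Hi. This is a longer sentence. Bye.']
# "Hi." is absorbed into the sentence after it, and the trailing
# "Bye." into the sentence before it.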
    	
        melo/text/__init__.py
    ADDED
    
from .symbols import *


_symbol_to_id = {s: i for i, s in enumerate(symbols)}


def cleaned_text_to_sequence(cleaned_text, tones, language, symbol_to_id=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      cleaned_text: cleaned phone symbols to convert to a sequence
      tones: per-phone tone indices
      language: language key used for the tone offset and language ID
      symbol_to_id: optional custom symbol-to-ID map (defaults to the shared table)
    Returns:
      Lists of phone IDs, offset tones, and per-phone language IDs
    """
    symbol_to_id_map = symbol_to_id if symbol_to_id else _symbol_to_id
    phones = [symbol_to_id_map[symbol] for symbol in cleaned_text]
    tone_start = language_tone_start_map[language]
    tones = [i + tone_start for i in tones]
    lang_id = language_id_map[language]
    lang_ids = [lang_id for i in phones]
    return phones, tones, lang_ids


def get_bert(norm_text, word2ph, language, device):
    from .chinese_bert import get_bert_feature as zh_bert
    from .english_bert import get_bert_feature as en_bert
    from .japanese_bert import get_bert_feature as jp_bert
    from .chinese_mix import get_bert_feature as zh_mix_en_bert
    from .spanish_bert import get_bert_feature as sp_bert
    from .french_bert import get_bert_feature as fr_bert
    from .korean import get_bert_feature as kr_bert

    lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
                          'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert}
    bert = lang_bert_func_map[language](norm_text, word2ph, device)
    return bert
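
For orientation, a sketch of how cleaned_text_to_sequence is meant to be called; the phone strings below are illustrative and must actually exist in melo.text.symbols (or in a custom symbol_to_id map) for the lookups to succeed:

from melo.text import cleaned_text_to_sequence

phones, tones, lang_ids = cleaned_text_to_sequence(
    ["_", "a", "b", "_"],    # cleaned phone symbols (illustrative)
    [0, 1, 2, 0],            # per-phone tones before the language offset
    "EN",
)
# phones   -> symbol-table IDs
# tones    -> input tones shifted by language_tone_start_map["EN"]
# lang_ids -> language_id_map["EN"], repeated once per phone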
    	
        melo/text/chinese.py
    ADDED
    
import os
import re

import cn2an
from pypinyin import lazy_pinyin, Style

from .symbols import punctuation
from .tone_sandhi import ToneSandhi

current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {
    line.split("\t")[0]: line.strip().split("\t")[1]
    for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}

import jieba.posseg as psg


rep_map = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}

tone_modifier = ToneSandhi()


def replace_punctuation(text):
    text = text.replace("嗯", "恩").replace("呣", "母")
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

    replaced_text = re.sub(
        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
    )

    return replaced_text


def g2p(text):
    pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
    sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
    phones, tones, word2ph = _g2p(sentences)
    assert sum(word2ph) == len(phones)
    assert len(word2ph) == len(text)  # Sometimes this crashes; you can add a try-catch.
    phones = ["_"] + phones + ["_"]
    tones = [0] + tones + [0]
    word2ph = [1] + word2ph + [1]
    return phones, tones, word2ph


def _get_initials_finals(word):
    initials = []
    finals = []
    orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
    orig_finals = lazy_pinyin(
        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
    )
    for c, v in zip(orig_initials, orig_finals):
        initials.append(c)
        finals.append(v)
    return initials, finals


def _g2p(segments):
    phones_list = []
    tones_list = []
    word2ph = []
    for seg in segments:
        # Remove all English words in the sentence
        seg = re.sub("[a-zA-Z]+", "", seg)
        seg_cut = psg.lcut(seg)
        initials = []
        finals = []
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        for word, pos in seg_cut:
            if pos == "eng":
                continue
            sub_initials, sub_finals = _get_initials_finals(word)
            sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
            initials.append(sub_initials)
            finals.append(sub_finals)

            # assert len(sub_initials) == len(sub_finals) == len(word)
        initials = sum(initials, [])
        finals = sum(finals, [])
        #
        for c, v in zip(initials, finals):
            raw_pinyin = c + v
            # NOTE: post process for pypinyin outputs
            # we discriminate i, ii and iii
            if c == v:
                assert c in punctuation
                phone = [c]
                tone = "0"
                word2ph.append(1)
            else:
                v_without_tone = v[:-1]
                tone = v[-1]

                pinyin = c + v_without_tone
                assert tone in "12345"

                if c:
                    # Syllable with an initial consonant
                    v_rep_map = {
                        "uei": "ui",
                        "iou": "iu",
                        "uen": "un",
                    }
                    if v_without_tone in v_rep_map.keys():
                        pinyin = c + v_rep_map[v_without_tone]
                else:
                    # Bare-final syllable (no initial)
                    pinyin_rep_map = {
                        "ing": "ying",
                        "i": "yi",
                        "in": "yin",
                        "u": "wu",
                    }
                    if pinyin in pinyin_rep_map.keys():
                        pinyin = pinyin_rep_map[pinyin]
                    else:
                        single_rep_map = {
                            "v": "yu",
                            "e": "e",
                            "i": "y",
                            "u": "w",
                        }
                        if pinyin[0] in single_rep_map.keys():
                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
                phone = pinyin_to_symbol_map[pinyin].split(" ")
                word2ph.append(len(phone))

            phones_list += phone
            tones_list += [int(tone)] * len(phone)
    return phones_list, tones_list, word2ph


def text_normalize(text):
    numbers = re.findall(r"\d+(?:\.?\d+)?", text)
    for number in numbers:
        text = text.replace(number, cn2an.an2cn(number), 1)
    text = replace_punctuation(text)
    return text


def get_bert_feature(text, word2ph, device=None):
    from text import chinese_bert

    return chinese_bert.get_bert_feature(text, word2ph, device=device)


if __name__ == "__main__":
    from text.chinese_bert import get_bert_feature

    text = "啊!chemistry 但是《原神》是由,米哈\游自主,  [研发]的一款全.新开放世界.冒险游戏"
    text = text_normalize(text)
    print(text)
    phones, tones, word2ph = g2p(text)
    bert = get_bert_feature(text, word2ph)

    print(phones, tones, word2ph, bert.shape)


# # Example usage
# text = "这是一个示例文本:,你好!这是一个测试...."
# print(g2p_paddle(text))  # Output: 这是一个示例文本你好这是一个测试
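
text_normalize above delegates digit spelling to cn2an before punctuation mapping; a minimal sketch of that step in isolation (assuming the cn2an package is installed, outputs shown as expected, not verified here):

import cn2an

print(cn2an.an2cn("123"))   # 一百二十三
print(cn2an.an2cn("3.5"))   # 三点五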
    	
        melo/text/chinese_bert.py
    ADDED
    
    | @@ -0,0 +1,107 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
```python
import sys

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# model_id = 'hfl/chinese-roberta-wwm-ext-large'
local_path = "./bert/chinese-roberta-wwm-ext-large"

# Lazily populated caches, keyed by model id, so repeated calls reuse the
# same tokenizer/model instead of reloading them.
tokenizers = {}
models = {}

def get_bert_feature(text, word2ph, device=None, model_id='hfl/chinese-roberta-wwm-ext-large'):
    # Resolve the target device *before* the model is loaded, so the cached
    # model lives on the same device the inputs are moved to (the original
    # ordering could load the model on CPU and then send the inputs to CUDA).
    if (
        sys.platform == "darwin"
        and torch.backends.mps.is_available()
        and device == "cpu"
    ):
        device = "mps"
    if not device:
        device = "cuda"

    if model_id not in models:
        models[model_id] = AutoModelForMaskedLM.from_pretrained(model_id).to(device)
        tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
    model = models[model_id]
    tokenizer = tokenizers[model_id]

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            inputs[i] = inputs[i].to(device)
        res = model(**inputs, output_hidden_states=True)
        # Keep the third-to-last hidden layer as the token-level feature.
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
    # assert len(word2ph) == len(text) + 2
    word2phone = word2ph
    phone_level_feature = []
    for i in range(len(word2phone)):
        # Repeat each token feature once per phone it covers.
        repeat_feature = res[i].repeat(word2phone[i], 1)
        phone_level_feature.append(repeat_feature)

    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    return phone_level_feature.T


if __name__ == "__main__":
    word_level_feature = torch.rand(38, 1024)  # 38 tokens, each with a 1024-dim feature
    word2phone = [
        1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 1, 1,
        2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1,
    ]

    # Compute the total number of phone-level frames.
    total_frames = sum(word2phone)
    print(word_level_feature.shape)
    print(word2phone)
    phone_level_feature = []
    for i in range(len(word2phone)):
        print(word_level_feature[i].shape)

        # Repeat each token's feature word2phone[i] times.
        repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
        phone_level_feature.append(repeat_feature)

    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    print(phone_level_feature.shape)  # torch.Size([65, 1024]) -- sum(word2phone) rows
```
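For orientation, a minimal usage sketch (hypothetical inputs, not from the repo): `word2ph` needs one entry per BERT token of `text`, including the [CLS]/[SEP] positions, and the returned matrix has one column per phone. Note the device defaults to CUDA when none is given, and the first call downloads the Hugging Face checkpoint.

```python
# Hypothetical example: "你好" tokenizes to [CLS] 你 好 [SEP] -> 4 tokens,
# so word2ph has 4 entries; their sum (here 6) is the number of phone columns.
feats = get_bert_feature("你好", [1, 2, 2, 1], device="cpu")
print(feats.shape)  # torch.Size([1024, 6]) -- roberta-large hidden size is 1024
```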
    	
melo/text/chinese_mix.py
ADDED (+253 lines)
```python
import os
import re

import cn2an
from pypinyin import lazy_pinyin, Style

# from text.symbols import punctuation
from .symbols import language_tone_start_map
from .tone_sandhi import ToneSandhi
from .english import g2p as g2p_en
from transformers import AutoTokenizer

punctuation = ["!", "?", "…", ",", ".", "'", "-"]
current_file_path = os.path.dirname(__file__)
pinyin_to_symbol_map = {
    line.split("\t")[0]: line.strip().split("\t")[1]
    for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
}

import jieba.posseg as psg


rep_map = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}

tone_modifier = ToneSandhi()


def replace_punctuation(text):
    text = text.replace("嗯", "恩").replace("呣", "母")
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
    # Keep only CJK ideographs, Latin letters, whitespace, and the allowed punctuation.
    replaced_text = re.sub(r"[^\u4e00-\u9fa5_a-zA-Z\s" + "".join(punctuation) + r"]+", "", replaced_text)
    replaced_text = re.sub(r"[\s]+", " ", replaced_text)

    return replaced_text


def g2p(text, impl='v2'):
    pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
    sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
    if impl == 'v1':
        _func = _g2p
    elif impl == 'v2':
        _func = _g2p_v2
    else:
        raise NotImplementedError()
    phones, tones, word2ph = _func(sentences)
    assert sum(word2ph) == len(phones)
    # assert len(word2ph) == len(text)  # Sometimes it will crash; you can add a try-except.
    phones = ["_"] + phones + ["_"]
    tones = [0] + tones + [0]
    word2ph = [1] + word2ph + [1]
    return phones, tones, word2ph


def _get_initials_finals(word):
    initials = []
    finals = []
    orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
    orig_finals = lazy_pinyin(
        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
    )
    for c, v in zip(orig_initials, orig_finals):
        initials.append(c)
        finals.append(v)
    return initials, finals

model_id = 'bert-base-multilingual-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_id)

def _g2p(segments):
    phones_list = []
    tones_list = []
    word2ph = []
    for seg in segments:
        # Replace all English words in the sentence
        # seg = re.sub("[a-zA-Z]+", "", seg)
        seg_cut = psg.lcut(seg)
        initials = []
        finals = []
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        for word, pos in seg_cut:
            if pos == "eng":
                initials.append(['EN_WORD'])
                finals.append([word])
            else:
                sub_initials, sub_finals = _get_initials_finals(word)
                sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
                initials.append(sub_initials)
                finals.append(sub_finals)

            # assert len(sub_initials) == len(sub_finals) == len(word)
        initials = sum(initials, [])
        finals = sum(finals, [])

        for c, v in zip(initials, finals):
            if c == 'EN_WORD':
                tokenized_en = tokenizer.tokenize(v)
                phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
                # apply offset to tones_en
                tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
                phones_list += phones_en
                tones_list += tones_en
                word2ph += word2ph_en
            else:
                raw_pinyin = c + v
                # NOTE: post-processing for pypinyin outputs;
                # we discriminate i, ii and iii
                if c == v:
                    assert c in punctuation
                    phone = [c]
                    tone = "0"
                    word2ph.append(1)
                else:
                    v_without_tone = v[:-1]
                    tone = v[-1]

                    pinyin = c + v_without_tone
                    assert tone in "12345"

                    if c:
                        # Syllable with an initial: normalize a few finals.
                        v_rep_map = {
                            "uei": "ui",
                            "iou": "iu",
                            "uen": "un",
                        }
                        if v_without_tone in v_rep_map.keys():
                            pinyin = c + v_rep_map[v_without_tone]
                    else:
                        # Syllable without an initial: rewrite to its standalone form.
                        pinyin_rep_map = {
                            "ing": "ying",
                            "i": "yi",
                            "in": "yin",
                            "u": "wu",
                        }
                        if pinyin in pinyin_rep_map.keys():
                            pinyin = pinyin_rep_map[pinyin]
                        else:
                            single_rep_map = {
                                "v": "yu",
                                "e": "e",
                                "i": "y",
                                "u": "w",
                            }
                            if pinyin[0] in single_rep_map.keys():
                                pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

                    assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
                    phone = pinyin_to_symbol_map[pinyin].split(" ")
                    word2ph.append(len(phone))

                phones_list += phone
                tones_list += [int(tone)] * len(phone)
    return phones_list, tones_list, word2ph


def text_normalize(text):
    numbers = re.findall(r"\d+(?:\.?\d+)?", text)
    for number in numbers:
        text = text.replace(number, cn2an.an2cn(number), 1)
    text = replace_punctuation(text)
    return text


def get_bert_feature(text, word2ph, device):
    from . import chinese_bert
    return chinese_bert.get_bert_feature(text, word2ph, model_id='bert-base-multilingual-uncased', device=device)

from .chinese import _g2p as _chinese_g2p

def _g2p_v2(segments):
    spliter = '#$&^!@'

    phones_list = []
    tones_list = []
    word2ph = []

    for text in segments:
        assert spliter not in text
        # Fence every English run with the sentinel, then split on it so the
        # text decomposes into alternating Chinese and English chunks.
        text = re.sub(r'([a-zA-Z\s]+)', lambda x: f'{spliter}{x.group(1)}{spliter}', text)
        texts = text.split(spliter)
        texts = [t for t in texts if len(t) > 0]

        for text in texts:
            if re.match(r'[a-zA-Z\s]+', text):
                # English chunk
                tokenized_en = tokenizer.tokenize(text)
                phones_en, tones_en, word2ph_en = g2p_en(text=None, pad_start_end=False, tokenized=tokenized_en)
                # apply offset to tones_en
                tones_en = [t + language_tone_start_map['EN'] for t in tones_en]
                phones_list += phones_en
                tones_list += tones_en
                word2ph += word2ph_en
            else:
                # Chinese chunk
                phones_zh, tones_zh, word2ph_zh = _chinese_g2p([text])
                phones_list += phones_zh
                tones_list += tones_zh
                word2ph += word2ph_zh
    return phones_list, tones_list, word2ph


if __name__ == "__main__":
    # from text.chinese_bert import get_bert_feature

    # Only the last assignment takes effect; the earlier lines are alternative test sentences.
    text = "NFT啊!chemistry 但是《原神》是由,米哈\游自主,  [研发]的一款全.新开放世界.冒险游戏"
    text = '我最近在学习machine learning,希望能够在未来的artificial intelligence领域有所建树。'
    text = '今天下午,我们准备去shopping mall购物,然后晚上去看一场movie。'
    text = '我们现在 also 能够 help 很多公司 use some machine learning 的 algorithms 啊!'
    text = text_normalize(text)
    print(text)
    phones, tones, word2ph = g2p(text, impl='v2')
    bert = get_bert_feature(text, word2ph, device='cuda:0')
    print(phones)
    import pdb; pdb.set_trace()


# Example usage:
# text = "这是一个示例文本:,你好!这是一个测试...."
# print(g2p_paddle(text))  # Output: 这是一个示例文本你好这是一个测试
```
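The heart of `_g2p_v2` is the sentinel-based segmentation; a standalone trace of just that step (the regex copied out, nothing imported from this module) makes the behavior concrete:

```python
import re

spliter = '#$&^!@'  # sentinel assumed absent from the input (asserted in _g2p_v2)
text = '我们现在 also 能够 help 很多公司'
fenced = re.sub(r'([a-zA-Z\s]+)', lambda m: f'{spliter}{m.group(1)}{spliter}', text)
chunks = [t for t in fenced.split(spliter) if t]
print(chunks)  # ['我们现在', ' also ', '能够', ' help ', '很多公司']
```

Each chunk is then routed to the English or the Chinese g2p depending on whether it matches `[a-zA-Z\s]+`.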
    	
melo/text/cleaner.py
ADDED (+36 lines)
```python
from . import chinese, japanese, english, chinese_mix, korean, french, spanish
from . import cleaned_text_to_sequence
import copy

language_module_map = {"ZH": chinese, "JP": japanese, "EN": english, 'ZH_MIX_EN': chinese_mix, 'KR': korean,
                       'FR': french, 'SP': spanish, 'ES': spanish}


def clean_text(text, language):
    language_module = language_module_map[language]
    norm_text = language_module.text_normalize(text)
    phones, tones, word2ph = language_module.g2p(norm_text)
    return norm_text, phones, tones, word2ph


def clean_text_bert(text, language, device=None):
    language_module = language_module_map[language]
    norm_text = language_module.text_normalize(text)
    phones, tones, word2ph = language_module.g2p(norm_text)

    # Keep the original mapping; the doubled copy below matches the
    # blank-interspersed phone sequence used at training time (one blank
    # between phones, plus one extra at the start).
    word2ph_bak = copy.deepcopy(word2ph)
    for i in range(len(word2ph)):
        word2ph[i] = word2ph[i] * 2
    word2ph[0] += 1
    bert = language_module.get_bert_feature(norm_text, word2ph, device=device)

    return norm_text, phones, tones, word2ph_bak, bert


def text_to_sequence(text, language):
    norm_text, phones, tones, word2ph = clean_text(text, language)
    return cleaned_text_to_sequence(phones, tones, language)


if __name__ == "__main__":
    pass
```
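A minimal usage sketch (assuming the language modules and their pretrained resources are available locally):

```python
# Hypothetical call; the first use loads the language's G2P resources.
norm_text, phones, tones, word2ph = clean_text("Hello world!", "EN")
# Each g2p implementation keeps sum(word2ph) == len(phones), so the
# word-to-phone map can expand token-level features to phone level.
assert sum(word2ph) == len(phones)
```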
    	
melo/text/cleaner_multiling.py
ADDED (+110 lines)
```python
"""Set of default text cleaners"""
# TODO: pick the cleaner for languages dynamically

import re

# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")

rep_map = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": ".",
    "…": ".",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "",
    "~": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}

def replace_punctuation(text):
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
    return replaced_text

def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text).strip()

def remove_punctuation_at_begin(text):
    return re.sub(r'^[,.!?]+', '', text)

def remove_aux_symbols(text):
    text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
    return text


def replace_symbols(text, lang="en"):
    """Replace symbols based on the language tag.

    Args:
      text:
        Input text.
      lang:
        Language identifier. ex: "en", "fr", "pt", "ca".

    Returns:
      The modified text
      example:
        input args:
            text: "si l'avi cau, diguem-ho"
            lang: "ca"
        Output:
            text: "si lavi cau, diguemho"
    """
    text = text.replace(";", ",")
    text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
    text = text.replace(":", ",")
    if lang == "en":
        text = text.replace("&", " and ")
    elif lang == "fr":
        text = text.replace("&", " et ")
    elif lang == "pt":
        text = text.replace("&", " e ")
    elif lang == "ca":
        text = text.replace("&", " i ")
        text = text.replace("'", "")
    elif lang == "es":
        text = text.replace("&", " y ")  # spaces added; the bare "y" would glue words together
        text = text.replace("'", "")
    return text

def unicleaners(text, cased=False, lang='en'):
    """Basic multilingual cleaning pipeline. There is no need to expand
    abbreviations or numbers; the phonemizer already does that."""
    if not cased:
        text = lowercase(text)
    text = replace_punctuation(text)
    text = replace_symbols(text, lang=lang)
    text = remove_aux_symbols(text)
    text = remove_punctuation_at_begin(text)
    text = collapse_whitespace(text)
    # Ensure the text ends with a sentence-final punctuation mark.
    text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
    return text
```
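To see the pipeline order in action, a small illustrative call (the input string is made up):

```python
# lowercase -> punctuation mapping -> language symbols -> aux-symbol removal
# -> leading-punctuation strip -> whitespace collapse -> final period
print(unicleaners("Bonjour; le «monde» —", lang='fr'))
# -> 'bonjour, le monde.'
```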
    	
melo/text/cmudict.rep
ADDED (diff too large to render; see raw diff)
    	
melo/text/cmudict_cache.pickle
ADDED (+3 lines)
```
version https://git-lfs.github.com/spec/v1
oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
size 6212655
```
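This is a Git LFS pointer file: the roughly 6.2 MB pickled CMU dictionary cache itself is stored in LFS and fetched on checkout, so only the spec version, object hash, and size appear in the diff.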
    	
melo/text/english.py
ADDED (+284 lines)
| 1 | 
            +
            import pickle
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import re
         | 
| 4 | 
            +
            from g2p_en import G2p
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from . import symbols
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from .english_utils.abbreviations import expand_abbreviations
         | 
| 9 | 
            +
            from .english_utils.time_norm import expand_time_english
         | 
| 10 | 
            +
            from .english_utils.number_norm import normalize_numbers
         | 
| 11 | 
            +
            from .japanese import distribute_phone
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from transformers import AutoTokenizer
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            current_file_path = os.path.dirname(__file__)
         | 
| 16 | 
            +
            CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
         | 
| 17 | 
            +
            CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
         | 
| 18 | 
            +
            _g2p = G2p()
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            arpa = {
         | 
| 21 | 
            +
                "AH0",
         | 
| 22 | 
            +
                "S",
         | 
| 23 | 
            +
                "AH1",
         | 
| 24 | 
            +
                "EY2",
         | 
| 25 | 
            +
                "AE2",
         | 
| 26 | 
            +
                "EH0",
         | 
| 27 | 
            +
                "OW2",
         | 
| 28 | 
            +
                "UH0",
         | 
| 29 | 
            +
                "NG",
         | 
| 30 | 
            +
                "B",
         | 
| 31 | 
            +
                "G",
         | 
| 32 | 
            +
                "AY0",
         | 
| 33 | 
            +
                "M",
         | 
| 34 | 
            +
                "AA0",
         | 
| 35 | 
            +
                "F",
         | 
| 36 | 
            +
                "AO0",
         | 
| 37 | 
            +
                "ER2",
         | 
| 38 | 
            +
                "UH1",
         | 
| 39 | 
            +
                "IY1",
         | 
| 40 | 
            +
                "AH2",
         | 
| 41 | 
            +
                "DH",
         | 
| 42 | 
            +
                "IY0",
         | 
| 43 | 
            +
                "EY1",
         | 
| 44 | 
            +
                "IH0",
         | 
| 45 | 
            +
                "K",
         | 
| 46 | 
            +
                "N",
         | 
| 47 | 
            +
                "W",
         | 
| 48 | 
            +
                "IY2",
         | 
| 49 | 
            +
                "T",
         | 
| 50 | 
            +
                "AA1",
         | 
| 51 | 
            +
                "ER1",
         | 
| 52 | 
            +
                "EH2",
         | 
| 53 | 
            +
                "OY0",
         | 
| 54 | 
            +
                "UH2",
         | 
| 55 | 
            +
                "UW1",
         | 
| 56 | 
            +
                "Z",
         | 
| 57 | 
            +
                "AW2",
         | 
| 58 | 
            +
                "AW1",
         | 
| 59 | 
            +
                "V",
         | 
| 60 | 
            +
                "UW2",
         | 
| 61 | 
            +
                "AA2",
         | 
| 62 | 
            +
                "ER",
         | 
| 63 | 
            +
                "AW0",
         | 
| 64 | 
            +
                "UW0",
         | 
| 65 | 
            +
                "R",
         | 
| 66 | 
            +
                "OW1",
         | 
| 67 | 
            +
                "EH1",
         | 
| 68 | 
            +
                "ZH",
         | 
| 69 | 
            +
                "AE0",
         | 
| 70 | 
            +
                "IH2",
         | 
| 71 | 
            +
                "IH",
         | 
| 72 | 
            +
                "Y",
         | 
| 73 | 
            +
                "JH",
         | 
| 74 | 
            +
                "P",
         | 
| 75 | 
            +
                "AY1",
         | 
| 76 | 
            +
                "EY0",
         | 
| 77 | 
            +
                "OY2",
         | 
| 78 | 
            +
                "TH",
         | 
| 79 | 
            +
                "HH",
         | 
| 80 | 
            +
                "D",
         | 
| 81 | 
            +
                "ER0",
         | 
| 82 | 
            +
                "CH",
         | 
| 83 | 
            +
                "AO1",
         | 
| 84 | 
            +
                "AE1",
         | 
| 85 | 
            +
                "AO2",
         | 
| 86 | 
            +
                "OY1",
         | 
| 87 | 
            +
                "AY2",
         | 
| 88 | 
            +
                "IH1",
         | 
| 89 | 
            +
                "OW0",
         | 
| 90 | 
            +
                "L",
         | 
| 91 | 
            +
                "SH",
         | 
| 92 | 
            +
            }
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            def post_replace_ph(ph):
         | 
| 96 | 
            +
                rep_map = {
         | 
| 97 | 
            +
                    ":": ",",
         | 
| 98 | 
            +
                    ";": ",",
         | 
| 99 | 
            +
                    ",": ",",
         | 
| 100 | 
            +
                    "。": ".",
         | 
| 101 | 
            +
                    "!": "!",
         | 
| 102 | 
            +
                    "?": "?",
         | 
| 103 | 
            +
                    "\n": ".",
         | 
| 104 | 
            +
                    "·": ",",
         | 
| 105 | 
            +
                    "、": ",",
         | 
| 106 | 
            +
                    "...": "…",
         | 
| 107 | 
            +
                    "v": "V",
         | 
| 108 | 
            +
                }
         | 
| 109 | 
            +
                if ph in rep_map.keys():
         | 
| 110 | 
            +
                    ph = rep_map[ph]
         | 
| 111 | 
            +
                if ph in symbols:
         | 
| 112 | 
            +
                    return ph
         | 
| 113 | 
            +
                if ph not in symbols:
         | 
| 114 | 
            +
                    ph = "UNK"
         | 
| 115 | 
            +
                return ph
         | 
| 116 | 
            +
             | 
| 117 | 
            +
             | 
| 118 | 
            +
            def read_dict():
         | 
| 119 | 
            +
                g2p_dict = {}
         | 
| 120 | 
            +
                start_line = 49
         | 
| 121 | 
            +
                with open(CMU_DICT_PATH) as f:
         | 
| 122 | 
            +
                    line = f.readline()
         | 
| 123 | 
            +
                    line_index = 1
         | 
| 124 | 
            +
                    while line:
         | 
| 125 | 
            +
                        if line_index >= start_line:
         | 
| 126 | 
            +
                            line = line.strip()
         | 
| 127 | 
            +
                            word_split = line.split("  ")
         | 
| 128 | 
            +
                            word = word_split[0]
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                            syllable_split = word_split[1].split(" - ")
         | 
| 131 | 
            +
                            g2p_dict[word] = []
         | 
| 132 | 
            +
                            for syllable in syllable_split:
         | 
| 133 | 
            +
                                phone_split = syllable.split(" ")
         | 
| 134 | 
            +
                                g2p_dict[word].append(phone_split)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                        line_index = line_index + 1
         | 
| 137 | 
            +
                        line = f.readline()
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                return g2p_dict
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
            def cache_dict(g2p_dict, file_path):
         | 
| 143 | 
            +
                with open(file_path, "wb") as pickle_file:
         | 
| 144 | 
            +
                    pickle.dump(g2p_dict, pickle_file)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            def get_dict():
         | 
| 148 | 
            +
                if os.path.exists(CACHE_PATH):
         | 
| 149 | 
            +
                    with open(CACHE_PATH, "rb") as pickle_file:
         | 
| 150 | 
            +
                        g2p_dict = pickle.load(pickle_file)
         | 
| 151 | 
            +
                else:
         | 
| 152 | 
            +
                    g2p_dict = read_dict()
         | 
| 153 | 
            +
                    cache_dict(g2p_dict, CACHE_PATH)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                return g2p_dict
         | 
| 156 | 
            +
             | 
| 157 | 
            +
             | 
| 158 | 
            +
            eng_dict = get_dict()
         | 
| 159 | 
            +
             | 
| 160 | 
            +
             | 
| 161 | 
            +
            def refine_ph(phn):
         | 
| 162 | 
            +
                tone = 0
         | 
| 163 | 
            +
                if re.search(r"\d$", phn):
         | 
| 164 | 
            +
                    tone = int(phn[-1]) + 1
         | 
| 165 | 
            +
                    phn = phn[:-1]
         | 
| 166 | 
            +
                return phn.lower(), tone
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            def refine_syllables(syllables):
         | 
| 170 | 
            +
                tones = []
         | 
| 171 | 
            +
                phonemes = []
         | 
| 172 | 
            +
                for phn_list in syllables:
         | 
| 173 | 
            +
                    for i in range(len(phn_list)):
         | 
| 174 | 
            +
                        phn = phn_list[i]
         | 
| 175 | 
            +
                        phn, tone = refine_ph(phn)
         | 
| 176 | 
            +
                        phonemes.append(phn)
         | 
| 177 | 
            +
                        tones.append(tone)
         | 
| 178 | 
            +
                return phonemes, tones
         | 
| 179 | 
            +
             | 
| 180 | 
            +
             | 
| 181 | 
            +
            def text_normalize(text):
         | 
| 182 | 
            +
                text = text.lower()
         | 
| 183 | 
            +
                text = expand_time_english(text)
         | 
| 184 | 
            +
                text = normalize_numbers(text)
         | 
| 185 | 
            +
                text = expand_abbreviations(text)
         | 
| 186 | 
            +
                return text
         | 
| 187 | 
            +
             | 
| 188 | 
            +
            model_id = 'bert-base-uncased'
         | 
| 189 | 
            +
            tokenizer = AutoTokenizer.from_pretrained(model_id)
         | 
| 190 | 
            +
            def g2p_old(text):
         | 
| 191 | 
            +
                tokenized = tokenizer.tokenize(text)
         | 
| 192 | 
            +
                # import pdb; pdb.set_trace()
         | 
| 193 | 
            +
                phones = []
         | 
| 194 | 
            +
                tones = []
         | 
| 195 | 
            +
                words = re.split(r"([,;.\-\?\!\s+])", text)
         | 
| 196 | 
            +
                for w in words:
         | 
| 197 | 
            +
                    if w.upper() in eng_dict:
         | 
| 198 | 
            +
                        phns, tns = refine_syllables(eng_dict[w.upper()])
         | 
| 199 | 
            +
                        phones += phns
         | 
| 200 | 
            +
                        tones += tns
         | 
| 201 | 
            +
                    else:
         | 
| 202 | 
            +
                        phone_list = list(filter(lambda p: p != " ", _g2p(w)))
         | 
| 203 | 
            +
                        for ph in phone_list:
         | 
| 204 | 
            +
                            if ph in arpa:
         | 
| 205 | 
            +
                                ph, tn = refine_ph(ph)
         | 
| 206 | 
            +
                                phones.append(ph)
                    tones.append(tn)
                else:
                    phones.append(ph)
                    tones.append(0)
    # todo: implement word2ph
    word2ph = [1 for i in phones]

    phones = [post_replace_ph(i) for i in phones]
    return phones, tones, word2ph


def g2p(text, pad_start_end=True, tokenized=None):
    if tokenized is None:
        tokenized = tokenizer.tokenize(text)
    phs = []
    ph_groups = []
    # group WordPiece continuation pieces ("##...") back onto their head token
    for t in tokenized:
        if not t.startswith("#"):
            ph_groups.append([t])
        else:
            ph_groups[-1].append(t.replace("#", ""))

    phones = []
    tones = []
    word2ph = []
    for group in ph_groups:
        w = "".join(group)
        phone_len = 0
        word_len = len(group)
        if w.upper() in eng_dict:
            phns, tns = refine_syllables(eng_dict[w.upper()])
            phones += phns
            tones += tns
            phone_len += len(phns)
        else:
            phone_list = list(filter(lambda p: p != " ", _g2p(w)))
            for ph in phone_list:
                if ph in arpa:
                    ph, tn = refine_ph(ph)
                    phones.append(ph)
                    tones.append(tn)
                else:
                    phones.append(ph)
                    tones.append(0)
                phone_len += 1
        word2ph += distribute_phone(phone_len, word_len)
    phones = [post_replace_ph(i) for i in phones]

    if pad_start_end:
        phones = ["_"] + phones + ["_"]
        tones = [0] + tones + [0]
        word2ph = [1] + word2ph + [1]
    return phones, tones, word2ph


def get_bert_feature(text, word2ph, device=None):
    from text import english_bert

    return english_bert.get_bert_feature(text, word2ph, device=device)


if __name__ == "__main__":
    # print(get_dict())
    # print(eng_word_to_phoneme("hello"))
    from text.english_bert import get_bert_feature

    text = "In this paper, we propose 1 DSPGAN, a N-F-T GAN-based universal vocoder."
    text = text_normalize(text)
    phones, tones, word2ph = g2p(text)
    bert = get_bert_feature(text, word2ph)

    print(phones, tones, word2ph, bert.shape)

    # all_phones = set()
    # for k, syllables in eng_dict.items():
    #     for group in syllables:
    #         for ph in group:
    #             all_phones.add(ph)
    # print(all_phones)
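
For orientation, a quick sanity sketch of the contract g2p maintains (my own example; it assumes this module's tokenizer and eng_dict are loaded):

    phones, tones, word2ph = g2p("hello world")
    assert len(phones) == len(tones) == sum(word2ph)
    # With pad_start_end=True, word2ph has one entry per BERT token of the
    # text plus the two "_" pad slots, which is what get_bert_feature checks.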
    	
melo/text/english_bert.py
ADDED

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import sys

model_id = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = None

def get_bert_feature(text, word2ph, device=None):
    global model
    if (
        sys.platform == "darwin"
        and torch.backends.mps.is_available()
        and device == "cpu"
    ):
        device = "mps"
    if not device:
        device = "cuda"
    if model is None:
        model = AutoModelForMaskedLM.from_pretrained(model_id).to(device)
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            inputs[i] = inputs[i].to(device)
        res = model(**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()

    assert inputs["input_ids"].shape[-1] == len(word2ph)
    word2phone = word2ph
    phone_level_feature = []
    for i in range(len(word2phone)):
        repeat_feature = res[i].repeat(word2phone[i], 1)
        phone_level_feature.append(repeat_feature)

    phone_level_feature = torch.cat(phone_level_feature, dim=0)

    return phone_level_feature.T
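
A hedged call sketch (my example; note word2ph must cover the [CLS]/[SEP] tokens too, per the assert above):

    # "hello world" tokenizes to [CLS] hello world [SEP] -> 4 tokens
    feats = get_bert_feature("hello world", word2ph=[1, 2, 2, 1], device="cpu")
    print(feats.shape)  # torch.Size([768, 6]): hidden_size x sum(word2ph)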
    	
melo/text/english_utils/__init__.py
ADDED

(empty file)
    	
melo/text/english_utils/abbreviations.py
ADDED

import re

# List of (regular expression, replacement) pairs for abbreviations in English:
abbreviations_en = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]

def expand_abbreviations(text, lang="en"):
    if lang == "en":
        _abbreviations = abbreviations_en
    else:
        raise NotImplementedError()
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text
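
A quick illustration (my example, not from the repo):

    print(expand_abbreviations("Dr. Smith and Mr. Jones met Lt. Dan."))
    # -> "doctor Smith and mister Jones met lieutenant Dan."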
    	
melo/text/english_utils/number_norm.py
ADDED

""" from https://github.com/keithito/tacotron """

import re
from typing import Dict

import inflect

_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
# "€" added so the euro entry in _expand_currency is actually reachable
_currency_re = re.compile(r"(£|\$|€|¥)([0-9\,\.]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"-?[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
    parts = value.replace(",", "").split(".")
    if len(parts) > 2:
        return f"{value} {inflection[2]}"  # Unexpected format
    text = []
    integer = int(parts[0]) if parts[0] else 0
    if integer > 0:
        integer_unit = inflection.get(integer, inflection[2])
        text.append(f"{integer} {integer_unit}")
    fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if fraction > 0:
        fraction_unit = inflection.get(fraction / 100, inflection[0.02])
        text.append(f"{fraction} {fraction_unit}")
    if len(text) == 0:
        return f"zero {inflection[2]}"
    return " ".join(text)


def _expand_currency(m: "re.Match") -> str:
    currencies = {
        "$": {
            0.01: "cent",
            0.02: "cents",
            1: "dollar",
            2: "dollars",
        },
        "€": {
            0.01: "cent",
            0.02: "cents",
            1: "euro",
            2: "euros",
        },
        "£": {
            0.01: "penny",
            0.02: "pence",
            1: "pound sterling",
            2: "pounds sterling",
        },
        "¥": {
            # TODO rin
            0.02: "sen",
            2: "yen",
        },
    }
    unit = m.group(1)
    currency = currencies[unit]
    value = m.group(2)
    return __expand_currency(value, currency)


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if 1000 < num < 3000:
        if num == 2000:
            return "two thousand"
        if 2000 < num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        if num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    return _inflect.number_to_words(num, andword="")


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_currency_re, _expand_currency, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
        melo/text/english_utils/time_norm.py
    ADDED
    
    | @@ -0,0 +1,47 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import re
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import inflect
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            _inflect = inflect.engine()
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            _time_re = re.compile(
         | 
| 8 | 
            +
                r"""\b
         | 
| 9 | 
            +
                                      ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3]))  # hours
         | 
| 10 | 
            +
                                      :
         | 
| 11 | 
            +
                                      ([0-5][0-9])                            # minutes
         | 
| 12 | 
            +
                                      \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm
         | 
| 13 | 
            +
                                      \b""",
         | 
| 14 | 
            +
                re.IGNORECASE | re.X,
         | 
| 15 | 
            +
            )
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def _expand_num(n: int) -> str:
         | 
| 19 | 
            +
                return _inflect.number_to_words(n)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def _expand_time_english(match: "re.Match") -> str:
         | 
| 23 | 
            +
                hour = int(match.group(1))
         | 
| 24 | 
            +
                past_noon = hour >= 12
         | 
| 25 | 
            +
                time = []
         | 
| 26 | 
            +
                if hour > 12:
         | 
| 27 | 
            +
                    hour -= 12
         | 
| 28 | 
            +
                elif hour == 0:
         | 
| 29 | 
            +
                    hour = 12
         | 
| 30 | 
            +
                    past_noon = True
         | 
| 31 | 
            +
                time.append(_expand_num(hour))
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                minute = int(match.group(6))
         | 
| 34 | 
            +
                if minute > 0:
         | 
| 35 | 
            +
                    if minute < 10:
         | 
| 36 | 
            +
                        time.append("oh")
         | 
| 37 | 
            +
                    time.append(_expand_num(minute))
         | 
| 38 | 
            +
                am_pm = match.group(7)
         | 
| 39 | 
            +
                if am_pm is None:
         | 
| 40 | 
            +
                    time.append("p m" if past_noon else "a m")
         | 
| 41 | 
            +
                else:
         | 
| 42 | 
            +
                    time.extend(list(am_pm.replace(".", "")))
         | 
| 43 | 
            +
                return " ".join(time)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def expand_time_english(text: str) -> str:
         | 
| 47 | 
            +
                return re.sub(_time_re, _expand_time_english, text)
         | 
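
Sample expansions (my examples):

    print(expand_time_english("The call is at 13:05."))  # -> "The call is at one oh five p m."
    print(expand_time_english("up at 9:41 am"))          # -> "up at nine forty-one a m"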
    	
melo/text/es_phonemizer/__init__.py
ADDED

(empty file)
    	
melo/text/es_phonemizer/base.py
ADDED

import abc
from typing import List, Tuple

from .punctuation import Punctuation


class BasePhonemizer(abc.ABC):
    """Base phonemizer class

    Phonemization follows the following steps:
        1. Preprocessing:
            - remove empty lines
            - remove punctuation
            - keep track of punctuation marks

        2. Phonemization:
            - convert text to phonemes

        3. Postprocessing:
            - join phonemes
            - restore punctuation marks

    Args:
        language (str):
            Language used by the phonemizer.

        punctuations (List[str]):
            List of punctuation marks to be preserved.

        keep_puncs (bool):
            Whether to preserve punctuation marks or not.
    """

    def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
        # ensure the backend is installed on the system
        if not self.is_available():
            raise RuntimeError("{} not installed on your system".format(self.name()))  # pragma: nocover

        # ensure the backend supports the requested language
        self._language = self._init_language(language)

        # set up punctuation processing
        self._keep_puncs = keep_puncs
        self._punctuator = Punctuation(punctuations)

    def _init_language(self, language):
        """Language initialization

        This method may be overloaded in child classes (see Segments backend)
        """
        if not self.is_supported_language(language):
            raise RuntimeError(f'language "{language}" is not supported by the {self.name()} backend')
        return language

    @property
    def language(self):
        """The language code configured to be used for phonemization"""
        return self._language

    @staticmethod
    @abc.abstractmethod
    def name():
        """The name of the backend"""
        ...

    @classmethod
    @abc.abstractmethod
    def is_available(cls):
        """Returns True if the backend is installed, False otherwise"""
        ...

    @classmethod
    @abc.abstractmethod
    def version(cls):
        """Return the backend version as a tuple (major, minor, patch)"""
        ...

    @staticmethod
    @abc.abstractmethod
    def supported_languages():
        """Return a dict of language codes -> name supported by the backend"""
        ...

    def is_supported_language(self, language):
        """Returns True if `language` is supported by the backend"""
        return language in self.supported_languages()

    @abc.abstractmethod
    def _phonemize(self, text, separator):
        """The main phonemization method"""

    def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
        """Preprocess the text before phonemization

        1. remove spaces
        2. remove punctuation

        Override this if you need a different behaviour
        """
        text = text.strip()
        if self._keep_puncs:
            # a tuple (text, punctuation marks)
            return self._punctuator.strip_to_restore(text)
        return [self._punctuator.strip(text)], []

    def _phonemize_postprocess(self, phonemized, punctuations) -> str:
        """Postprocess the raw phonemized output

        Override this if you need a different behaviour
        """
        if self._keep_puncs:
            return self._punctuator.restore(phonemized, punctuations)[0]
        return phonemized[0]

    def phonemize(self, text: str, separator="|", language: str = None) -> str:  # pylint: disable=unused-argument
        """Returns the `text` phonemized for the given language

        Args:
            text (str):
                Text to be phonemized.

            separator (str):
                String separator used between phonemes. Defaults to '|'.

        Returns:
            (str): Phonemized text
        """
        text, punctuations = self._phonemize_preprocess(text)
        phonemized = []
        for t in text:
            p = self._phonemize(t, separator)
            phonemized.append(p)
        phonemized = self._phonemize_postprocess(phonemized, punctuations)
        return phonemized

    def print_logs(self, level: int = 0):
        indent = "\t" * level
        print(f"{indent}| > phoneme language: {self.language}")
        print(f"{indent}| > phoneme backend: {self.name()}")
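
To make the abstract contract concrete, a minimal hypothetical subclass (a toy backend of my own, purely illustrative):

    class EchoPhonemizer(BasePhonemizer):
        """Toy backend that 'phonemizes' by spelling out characters."""

        @staticmethod
        def name():
            return "echo"

        @classmethod
        def is_available(cls):
            return True

        @classmethod
        def version(cls):
            return (0, 0, 1)

        @staticmethod
        def supported_languages():
            return {"es-es": "Spanish (Spain)"}

        def _phonemize(self, text, separator):
            return separator.join(text)

    # EchoPhonemizer("es-es").phonemize("hola") -> "h|o|l|a"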
    	
melo/text/es_phonemizer/cleaner.py
ADDED

"""Set of default text cleaners"""
# TODO: pick the cleaner for languages dynamically

import re

# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")

rep_map = {
    "：": ",",
    "；": ",",
    "，": ",",
    "。": ".",
    "！": "!",
    "？": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": ".",
    "…": ".",
    "$": ".",
    "“": "'",
    "”": "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "（": "'",
    "）": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "",
    "～": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}


def replace_punctuation(text):
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
    return replaced_text


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text).strip()


def remove_punctuation_at_begin(text):
    return re.sub(r'^[,.!?]+', '', text)


def remove_aux_symbols(text):
    text = re.sub(r"[\<\>\(\)\[\]\"\«\»\']+", "", text)
    return text


def replace_symbols(text, lang="en"):
    """Replace symbols based on the language tag.

    Args:
      text:
       Input text.
      lang:
        Language identifier. ex: "en", "fr", "pt", "ca".

    Returns:
      The modified text
      example:
        input args:
            text: "si l'avi cau, diguem-ho"
            lang: "ca"
        Output:
            text: "si lavi cau, diguemho"
    """
    text = text.replace(";", ",")
    text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
    text = text.replace(":", ",")
    if lang == "en":
        text = text.replace("&", " and ")
    elif lang == "fr":
        text = text.replace("&", " et ")
    elif lang == "pt":
        text = text.replace("&", " e ")
    elif lang == "ca":
        text = text.replace("&", " i ")
        text = text.replace("'", "")
    elif lang == "es":
        text = text.replace("&", " y ")  # padded with spaces, like the other languages
        text = text.replace("'", "")
    return text


def spanish_cleaners(text):
    """Basic pipeline for Spanish text. There is no need to expand abbreviations and
    numbers, the phonemizer already does that"""
    text = lowercase(text)
    text = replace_symbols(text, lang="es")
    text = replace_punctuation(text)
    text = remove_aux_symbols(text)
    text = remove_punctuation_at_begin(text)
    text = collapse_whitespace(text)
    text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
    return text
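
An illustrative run of the pipeline (my example, assuming the space-padded "&" replacement above):

    print(spanish_cleaners("¿Cuánto cuesta & cuándo llega?"))
    # -> "¿cuánto cuesta y cuándo llega?"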
    	
melo/text/es_phonemizer/es_symbols.json
ADDED

{
    "symbols": [
        "_",
        ",",
        ".",
        "!",
        "?",
        "-",
        "~",
        "\u2026",
        "N",
        "Q",
        "a",
        "b",
        "d",
        "e",
        "f",
        "g",
        "h",
        "i",
        "j",
        "k",
        "l",
        "m",
        "n",
        "o",
        "p",
        "s",
        "t",
        "u",
        "v",
        "w",
        "x",
        "y",
        "z",
        "\u0251",
        "\u00e6",
        "\u0283",
        "\u0291",
        "\u00e7",
        "\u026f",
        "\u026a",
        "\u0254",
        "\u025b",
        "\u0279",
        "\u00f0",
        "\u0259",
        "\u026b",
        "\u0265",
        "\u0278",
        "\u028a",
        "\u027e",
        "\u0292",
        "\u03b8",
        "\u03b2",
        "\u014b",
        "\u0266",
        "\u207c",
        "\u02b0",
        "`",
        "^",
        "#",
        "*",
        "=",
        "\u02c8",
        "\u02cc",
        "\u2192",
        "\u2193",
        "\u2191",
        " ",
        "\u0263",
        "\u0261",
        "r",
        "\u0272",
        "\u029d",
        "\u028e",
        "\u02d0"
    ]
}
    	
melo/text/es_phonemizer/es_symbols.txt
ADDED

_,.!?-~…NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ɡrɲʝɣʎː—¿¡
    	
melo/text/es_phonemizer/es_symbols_v2.json
ADDED

{
    "symbols": [
        "_",
        ",",
        ".",
        "!",
        "?",
        "-",
        "~",
        "\u2026",
        "N",
        "Q",
        "a",
        "b",
        "d",
        "e",
        "f",
        "g",
        "h",
        "i",
        "j",
        "k",
        "l",
        "m",
        "n",
        "o",
        "p",
        "s",
        "t",
        "u",
        "v",
        "w",
        "x",
        "y",
        "z",
        "\u0251",
        "\u00e6",
        "\u0283",
        "\u0291",
        "\u00e7",
        "\u026f",
        "\u026a",
        "\u0254",
        "\u025b",
        "\u0279",
        "\u00f0",
        "\u0259",
        "\u026b",
        "\u0265",
        "\u0278",
        "\u028a",
        "\u027e",
        "\u0292",
        "\u03b8",
        "\u03b2",
        "\u014b",
        "\u0266",
        "\u207c",
        "\u02b0",
        "`",
        "^",
        "#",
        "*",
        "=",
        "\u02c8",
        "\u02cc",
        "\u2192",
        "\u2193",
        "\u2191",
        " ",
        "\u0261",
        "r",
        "\u0272",
        "\u029d",
        "\u0263",
        "\u028e",
        "\u02d0",
        "\u2014",
        "\u00bf",
        "\u00a1"
    ]
}
    	
melo/text/es_phonemizer/es_to_ipa.py
ADDED

from .cleaner import spanish_cleaners
from .gruut_wrapper import Gruut

def es2ipa(text):
    e = Gruut(language="es-es", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
    # text = spanish_cleaners(text)
    phonemes = e.phonemize(text, separator="")
    return phonemes


if __name__ == '__main__':
    print(es2ipa('¿Y a quién echaría de menos, en el mundo si no fuese a vos?'))
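
es2ipa constructs a fresh Gruut backend on every call; if that shows up as overhead, one option is a module-level cache (my sketch, not in the repo):

    _gruut = None

    def es2ipa_cached(text):
        global _gruut
        if _gruut is None:
            _gruut = Gruut(language="es-es", keep_puncs=True,
                           keep_stress=True, use_espeak_phonemes=True)
        return _gruut.phonemize(text, separator="")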
    	
        melo/text/es_phonemizer/gruut_wrapper.py
    ADDED
    
    | @@ -0,0 +1,253 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import importlib
         | 
| 2 | 
            +
            from typing import List
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import gruut
         | 
| 5 | 
            +
            from gruut_ipa import IPA # pip install gruut_ipa
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .base import BasePhonemizer
         | 
| 8 | 
            +
            from .punctuation import Punctuation
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Table for str.translate to fix gruut/TTS phoneme mismatch
         | 
| 11 | 
            +
            GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class Gruut(BasePhonemizer):
         | 
| 15 | 
            +
                """Gruut wrapper for G2P
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                Args:
         | 
| 18 | 
            +
                    language (str):
         | 
| 19 | 
            +
                        Valid language code for the used backend.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                    punctuations (str):
         | 
| 22 | 
            +
                        Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    keep_puncs (bool):
         | 
| 25 | 
            +
                        If true, keep the punctuations after phonemization. Defaults to True.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    use_espeak_phonemes (bool):
         | 
| 28 | 
            +
                        If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    keep_stress (bool):
         | 
| 31 | 
            +
                        If true, keep the stress characters after phonemization. Defaults to False.
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                Example:
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
         | 
| 36 | 
            +
                    >>> phonemizer = Gruut('en-us')
         | 
| 37 | 
            +
                    >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
         | 
| 38 | 
            +
                    'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def __init__(
         | 
| 42 | 
            +
                    self,
         | 
| 43 | 
            +
                    language: str,
         | 
| 44 | 
            +
                    punctuations=Punctuation.default_puncs(),
         | 
| 45 | 
            +
                    keep_puncs=True,
         | 
| 46 | 
            +
                    use_espeak_phonemes=False,
         | 
| 47 | 
            +
                    keep_stress=False,
         | 
| 48 | 
            +
                ):
         | 
| 49 | 
            +
                    super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
         | 
| 50 | 
            +
                    self.use_espeak_phonemes = use_espeak_phonemes
         | 
| 51 | 
            +
                    self.keep_stress = keep_stress
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                @staticmethod
         | 
| 54 | 
            +
                def name():
         | 
| 55 | 
            +
                    return "gruut"
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str:  # pylint: disable=unused-argument
         | 
| 58 | 
            +
                    """Convert input text to phonemes.
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
         | 
| 61 | 
            +
                    that constitute a single sound.
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    It doesn't affect 🐸TTS since it individually converts each character to token IDs.
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    Examples::
         | 
| 66 | 
            +
                        "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    Args:
         | 
| 69 | 
            +
                        text (str):
         | 
| 70 | 
            +
                            Text to be converted to phonemes.
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                        tie (bool, optional) : When True use a '͡' character between
         | 
| 73 | 
            +
                        consecutive characters of a single phoneme. Otherwise separate phonemes
         | 
| 74 | 
            +
                        with '_'. This option requires espeak>=1.49. Defaults to False.
         | 
| 75 | 
            +
                    """
         | 
| 76 | 
            +
                    ph_list = []
         | 
| 77 | 
            +
                    for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
         | 
| 78 | 
            +
                        for word in sentence:
         | 
| 79 | 
            +
                            if word.is_break:
         | 
| 80 | 
            +
                                # Use actual character for break phoneme (e.g., comma)
         | 
| 81 | 
            +
                                if ph_list:
         | 
| 82 | 
            +
                                    # Join with previous word
         | 
| 83 | 
            +
                                    ph_list[-1].append(word.text)
         | 
| 84 | 
            +
                                else:
         | 
| 85 | 
            +
                                    # First word is punctuation
         | 
| 86 | 
            +
                                    ph_list.append([word.text])
         | 
| 87 | 
            +
                            elif word.phonemes:
         | 
| 88 | 
            +
                                # Add phonemes for word
         | 
| 89 | 
            +
                                word_phonemes = []
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                                for word_phoneme in word.phonemes:
         | 
| 92 | 
            +
                                    if not self.keep_stress:
         | 
| 93 | 
            +
                                        # Remove primary/secondary stress
         | 
| 94 | 
            +
                                        word_phoneme = IPA.without_stress(word_phoneme)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                                    word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                                    if word_phoneme:
         | 
| 99 | 
            +
                                        # Flatten phonemes
         | 
| 100 | 
            +
                                        word_phonemes.extend(word_phoneme)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                                if word_phonemes:
         | 
| 103 | 
            +
                                    ph_list.append(word_phonemes)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
         | 
| 106 | 
            +
                    ph = f"{separator} ".join(ph_words)
         | 
| 107 | 
            +
                    return ph
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def _phonemize(self, text, separator):
         | 
| 110 | 
            +
                    return self.phonemize_gruut(text, separator, tie=False)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def is_supported_language(self, language):
         | 
| 113 | 
            +
                    """Returns True if `language` is supported by the backend"""
         | 
| 114 | 
            +
                    return gruut.is_language_supported(language)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                @staticmethod
         | 
| 117 | 
            +
                def supported_languages() -> List:
         | 
| 118 | 
            +
                    """Get a dictionary of supported languages.
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    Returns:
         | 
| 121 | 
            +
                        List: List of language codes.
         | 
| 122 | 
            +
                    """
         | 
| 123 | 
            +
                    return list(gruut.get_supported_languages())
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def version(self):
         | 
| 126 | 
            +
                    """Get the version of the used backend.
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    Returns:
         | 
| 129 | 
            +
                        str: Version of the used backend.
         | 
| 130 | 
            +
                    """
         | 
| 131 | 
            +
                    return gruut.__version__
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                @classmethod
         | 
| 134 | 
            +
                def is_available(cls):
         | 
| 135 | 
            +
                    """Return true if ESpeak is available else false"""
         | 
| 136 | 
            +
                    return importlib.util.find_spec("gruut") is not None
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            if __name__ == "__main__":
         | 
| 140 | 
            +
                from es_to_ipa import es2ipa
         | 
| 141 | 
            +
                import json
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                e = Gruut(language="es-es", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
         | 
| 144 | 
            +
                symbols = [
         | 
| 145 | 
            +
                    "_",
         | 
| 146 | 
            +
                    ",",
         | 
| 147 | 
            +
                    ".",
         | 
| 148 | 
            +
                    "!",
         | 
| 149 | 
            +
                    "?",
         | 
| 150 | 
            +
                    "-",
         | 
| 151 | 
            +
                    "~",
         | 
| 152 | 
            +
                    "\u2026",
         | 
| 153 | 
            +
                    "N",
         | 
| 154 | 
            +
                    "Q",
         | 
| 155 | 
            +
                    "a",
         | 
| 156 | 
            +
                    "b",
         | 
| 157 | 
            +
                    "d",
         | 
| 158 | 
            +
                    "e",
         | 
| 159 | 
            +
                    "f",
         | 
| 160 | 
            +
                    "g",
         | 
| 161 | 
            +
                    "h",
         | 
| 162 | 
            +
                    "i",
         | 
| 163 | 
            +
                    "j",
         | 
| 164 | 
            +
                    "k",
         | 
| 165 | 
            +
                    "l",
         | 
| 166 | 
            +
                    "m",
         | 
| 167 | 
            +
                    "n",
         | 
| 168 | 
            +
                    "o",
         | 
| 169 | 
            +
                    "p",
         | 
| 170 | 
            +
                    "s",
         | 
| 171 | 
            +
                    "t",
         | 
| 172 | 
            +
                    "u",
         | 
| 173 | 
            +
                    "v",
         | 
| 174 | 
            +
                    "w",
         | 
| 175 | 
            +
                    "x",
         | 
| 176 | 
            +
                    "y",
         | 
| 177 | 
            +
                    "z",
         | 
| 178 | 
            +
                    "\u0251",
         | 
| 179 | 
            +
                    "\u00e6",
         | 
| 180 | 
            +
                    "\u0283",
         | 
| 181 | 
            +
                    "\u0291",
         | 
| 182 | 
            +
                    "\u00e7",
         | 
| 183 | 
            +
                    "\u026f",
         | 
| 184 | 
            +
                    "\u026a",
         | 
| 185 | 
            +
                    "\u0254",
         | 
| 186 | 
            +
                    "\u025b",
         | 
| 187 | 
            +
                    "\u0279",
         | 
| 188 | 
            +
                    "\u00f0",
         | 
| 189 | 
            +
                    "\u0259",
         | 
| 190 | 
            +
                    "\u026b",
         | 
| 191 | 
            +
                    "\u0265",
         | 
| 192 | 
            +
                    "\u0278",
         | 
| 193 | 
            +
                    "\u028a",
         | 
| 194 | 
            +
                    "\u027e",
         | 
| 195 | 
            +
                    "\u0292",
         | 
| 196 | 
            +
                    "\u03b8",
         | 
| 197 | 
            +
                    "\u03b2",
         | 
| 198 | 
            +
                    "\u014b",
         | 
| 199 | 
            +
                    "\u0266",
         | 
| 200 | 
            +
                    "\u207c",
         | 
| 201 | 
            +
                    "\u02b0",
         | 
| 202 | 
            +
                    "`",
         | 
| 203 | 
            +
                    "^",
         | 
| 204 | 
            +
                    "#",
         | 
| 205 | 
            +
                    "*",
         | 
| 206 | 
            +
                    "=",
         | 
| 207 | 
            +
                    "\u02c8",
         | 
| 208 | 
            +
                    "\u02cc",
         | 
| 209 | 
            +
                    "\u2192",
         | 
| 210 | 
            +
                    "\u2193",
         | 
| 211 | 
            +
                    "\u2191",
         | 
| 212 | 
            +
                    " ",
         | 
| 213 | 
            +
                ]
         | 
| 214 | 
            +
                with open('./text/es_phonemizer/spanish_text.txt', 'r') as f:
         | 
| 215 | 
            +
                    lines = f.readlines()
         | 
| 216 | 
            +
                
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                used_sym = []
         | 
| 219 | 
            +
                not_existed_sym = []
         | 
| 220 | 
            +
                phonemes = []
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                for line in lines[:400]:
         | 
| 223 | 
            +
                    text = line.split('|')[-1].strip()
         | 
| 224 | 
            +
                    ipa = es2ipa(text)
         | 
| 225 | 
            +
                    phonemes.append(ipa + '\n')
         | 
| 226 | 
            +
                    for s in ipa:
         | 
| 227 | 
            +
                        if s not in symbols:
         | 
| 228 | 
            +
                            if s not in not_existed_sym:
         | 
| 229 | 
            +
                                print(f'not_existed char: {s}')
         | 
| 230 | 
            +
                                not_existed_sym.append(s)
         | 
| 231 | 
            +
                        else:
         | 
| 232 | 
            +
                            if s not in used_sym:
         | 
| 233 | 
            +
                                # print(f'used char: {s}')
         | 
| 234 | 
            +
                                used_sym.append(s)
         | 
| 235 | 
            +
                
         | 
| 236 | 
            +
                print(used_sym)
         | 
| 237 | 
            +
                print(not_existed_sym)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
             | 
| 240 | 
            +
                with open('./text/es_phonemizer/es_symbols.txt', 'w') as g:
         | 
| 241 | 
            +
                    g.writelines(symbols + not_existed_sym)
         | 
| 242 | 
            +
                    
         | 
| 243 | 
            +
                with open('./text/es_phonemizer/example_ipa.txt', 'w') as g:
         | 
| 244 | 
            +
                    g.writelines(phonemes)
         | 
| 245 | 
            +
                    
         | 
| 246 | 
            +
                data = {'symbols': symbols + not_existed_sym}
         | 
| 247 | 
            +
                with open('./text/es_phonemizer/es_symbols_v2.json', 'w') as f:
         | 
| 248 | 
            +
                    json.dump(data, f, indent=4)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
             | 
| 252 | 
            +
             | 
| 253 | 
            +
             | 
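
A minimal usage sketch for the wrapper above, assuming gruut and gruut_ipa are installed and the module is imported through the melo package; the exact phonemes depend on the installed gruut lexicons, so any printed output is illustrative only.

# Hypothetical quick check of the Gruut wrapper defined above.
from melo.text.es_phonemizer.gruut_wrapper import Gruut

phonemizer = Gruut(language="es-es", keep_puncs=True)
# Each word becomes separator-joined phonemes, words are joined by "<sep> ",
# and punctuation is restored by the BasePhonemizer machinery.
print(phonemizer.phonemize("Hola, mundo.", separator="|"))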
    	
        melo/text/es_phonemizer/punctuation.py
    ADDED
    
    | @@ -0,0 +1,174 @@ | |
| 1 | 
            +
            import collections
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            from enum import Enum
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import six
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            _DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            _PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class PuncPosition(Enum):
         | 
| 13 | 
            +
                """Enum for the punctuations positions"""
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                BEGIN = 0
         | 
| 16 | 
            +
                END = 1
         | 
| 17 | 
            +
                MIDDLE = 2
         | 
| 18 | 
            +
                ALONE = 3
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            class Punctuation:
         | 
| 22 | 
            +
                """Handle punctuations in text.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                Just strip punctuations from text or strip and restore them later.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                Args:
         | 
| 27 | 
            +
                    puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Example:
         | 
| 30 | 
            +
                    >>> punc = Punctuation()
         | 
| 31 | 
            +
                    >>> punc.strip("This is. example !")
         | 
| 32 | 
            +
                    'This is example'
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
         | 
| 35 | 
            +
                    >>> ' '.join(text_striped)
         | 
| 36 | 
            +
                    'This is example'
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    >>> text_restored = punc.restore(text_striped, punc_map)
         | 
| 39 | 
            +
                    >>> text_restored[0]
         | 
| 40 | 
            +
                    'This is. example !'
         | 
| 41 | 
            +
                """
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                def __init__(self, puncs: str = _DEF_PUNCS):
         | 
| 44 | 
            +
                    self.puncs = puncs
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                @staticmethod
         | 
| 47 | 
            +
                def default_puncs():
         | 
| 48 | 
            +
                    """Return default set of punctuations."""
         | 
| 49 | 
            +
                    return _DEF_PUNCS
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                @property
         | 
| 52 | 
            +
                def puncs(self):
         | 
| 53 | 
            +
                    return self._puncs
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                @puncs.setter
         | 
| 56 | 
            +
                def puncs(self, value):
         | 
| 57 | 
            +
                    if not isinstance(value, six.string_types):
         | 
| 58 | 
            +
                        raise ValueError("[!] Punctuations must be of type str.")
         | 
| 59 | 
            +
                self._puncs = "".join(list(dict.fromkeys(list(value))))  # remove duplicates without changing the order
         | 
| 60 | 
            +
                    self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def strip(self, text):
         | 
| 63 | 
            +
                    """Remove all the punctuations by replacing with `space`.
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    Args:
         | 
| 66 | 
            +
                        text (str): The text to be processed.
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    Example::
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                        "This is. example !" -> "This is example "
         | 
| 71 | 
            +
                    """
         | 
| 72 | 
            +
                    return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip()
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                def strip_to_restore(self, text):
         | 
| 75 | 
            +
                    """Remove punctuations from text to restore them later.
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    Args:
         | 
| 78 | 
            +
                        text (str): The text to be processed.
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    Examples ::
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                        "This is. example !" -> [["This is", "example"], [".", "!"]]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    """
         | 
| 85 | 
            +
                    text, puncs = self._strip_to_restore(text)
         | 
| 86 | 
            +
                    return text, puncs
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def _strip_to_restore(self, text):
         | 
| 89 | 
            +
                    """Auxiliary method for Punctuation.preserve()"""
         | 
| 90 | 
            +
                    matches = list(re.finditer(self.puncs_regular_exp, text))
         | 
| 91 | 
            +
                    if not matches:
         | 
| 92 | 
            +
                        return [text], []
         | 
| 93 | 
            +
                    # the text is only punctuations
         | 
| 94 | 
            +
                    if len(matches) == 1 and matches[0].group() == text:
         | 
| 95 | 
            +
                        return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
         | 
| 96 | 
            +
                    # build a punctuation map to be used later to restore punctuations
         | 
| 97 | 
            +
                    puncs = []
         | 
| 98 | 
            +
                    for match in matches:
         | 
| 99 | 
            +
                        position = PuncPosition.MIDDLE
         | 
| 100 | 
            +
                        if match == matches[0] and text.startswith(match.group()):
         | 
| 101 | 
            +
                            position = PuncPosition.BEGIN
         | 
| 102 | 
            +
                        elif match == matches[-1] and text.endswith(match.group()):
         | 
| 103 | 
            +
                            position = PuncPosition.END
         | 
| 104 | 
            +
                        puncs.append(_PUNC_IDX(match.group(), position))
         | 
| 105 | 
            +
                    # convert str text to a List[str], each item is separated by a punctuation
         | 
| 106 | 
            +
                    splitted_text = []
         | 
| 107 | 
            +
                    for idx, punc in enumerate(puncs):
         | 
| 108 | 
            +
                        split = text.split(punc.punc)
         | 
| 109 | 
            +
                        prefix, suffix = split[0], punc.punc.join(split[1:])
         | 
| 110 | 
            +
                        splitted_text.append(prefix)
         | 
| 111 | 
            +
                        # if the text does not end with a punctuation, add it to the last item
         | 
| 112 | 
            +
                        if idx == len(puncs) - 1 and len(suffix) > 0:
         | 
| 113 | 
            +
                            splitted_text.append(suffix)
         | 
| 114 | 
            +
                        text = suffix
         | 
| 115 | 
            +
                    while splitted_text[0] == '':
         | 
| 116 | 
            +
                        splitted_text = splitted_text[1:]
         | 
| 117 | 
            +
                    return splitted_text, puncs
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                @classmethod
         | 
| 120 | 
            +
                def restore(cls, text, puncs):
         | 
| 121 | 
            +
                    """Restore punctuation in a text.
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    Args:
         | 
| 124 | 
            +
                        text (str): The text to be processed.
         | 
| 125 | 
            +
                        puncs (List[str]): The list of punctuations map to be used for restoring.
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    Examples ::
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        ['This is', 'example'], ['.', '!'] -> "This is. example!"
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    """
         | 
| 132 | 
            +
                    return cls._restore(text, puncs, 0)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                @classmethod
         | 
| 135 | 
            +
                def _restore(cls, text, puncs, num):  # pylint: disable=too-many-return-statements
         | 
| 136 | 
            +
                    """Auxiliary method for Punctuation.restore()"""
         | 
| 137 | 
            +
                    if not puncs:
         | 
| 138 | 
            +
                        return text
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                # nothing has been phonemized; return the puncs alone
         | 
| 141 | 
            +
                    if not text:
         | 
| 142 | 
            +
                        return ["".join(m.punc for m in puncs)]
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    current = puncs[0]
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    if current.position == PuncPosition.BEGIN:
         | 
| 147 | 
            +
                        return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    if current.position == PuncPosition.END:
         | 
| 150 | 
            +
                        return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    if current.position == PuncPosition.ALONE:
         | 
| 153 | 
            +
                    return [current.punc] + cls._restore(text, puncs[1:], num + 1)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    # POSITION == MIDDLE
         | 
| 156 | 
            +
                    if len(text) == 1:  # pragma: nocover
         | 
| 157 | 
            +
                        # a corner case where the final part of an intermediate
         | 
| 158 | 
            +
                        # mark (I) has not been phonemized
         | 
| 159 | 
            +
                        return cls._restore([text[0] + current.punc], puncs[1:], num)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
             | 
| 164 | 
            +
            # if __name__ == "__main__":
         | 
| 165 | 
            +
            #     punc = Punctuation()
         | 
| 166 | 
            +
            #     text = "This is. This is, example!"
         | 
| 167 | 
            +
             | 
| 168 | 
            +
            #     print(punc.strip(text))
         | 
| 169 | 
            +
             | 
| 170 | 
            +
            #     split_text, puncs = punc.strip_to_restore(text)
         | 
| 171 | 
            +
            #     print(split_text, " ---- ", puncs)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
            #     restored_text = punc.restore(split_text, puncs)
         | 
| 174 | 
            +
            #     print(restored_text)
         | 
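
A short round-trip sketch of the strip/restore logic above, mirroring the commented-out demo; note that restore() returns a single-item list, hence the [0].

from melo.text.es_phonemizer.punctuation import Punctuation

punc = Punctuation()
# Split off punctuation while remembering what was removed and where.
parts, puncs = punc.strip_to_restore("This is. example !")
# parts -> ['This is', 'example']; puncs records ('. ', MIDDLE) and (' !', END)
restored = Punctuation.restore(parts, puncs)
print(restored[0])  # -> 'This is. example !'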
    	
        melo/text/es_phonemizer/spanish_symbols.txt
    ADDED
    
    | @@ -0,0 +1 @@ | |
| 1 | 
            +
            dˌaβˈiðkopeɾfjl unθsbmtʃwɛxɪŋʊɣɡrɲʝʎː
         | 
    	
        melo/text/es_phonemizer/test.ipynb
    ADDED
    
    | @@ -0,0 +1,124 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "code",
         | 
| 5 | 
            +
               "execution_count": 1,
         | 
| 6 | 
            +
               "metadata": {},
         | 
| 7 | 
            +
               "outputs": [
         | 
| 8 | 
            +
                {
         | 
| 9 | 
            +
                 "ename": "ImportError",
         | 
| 10 | 
            +
                 "evalue": "attempted relative import with no known parent package",
         | 
| 11 | 
            +
                 "output_type": "error",
         | 
| 12 | 
            +
                 "traceback": [
         | 
| 13 | 
            +
                  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
         | 
| 14 | 
            +
                  "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
         | 
| 15 | 
            +
                  "\u001b[1;32m/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb Cell 1\u001b[0m line \u001b[0;36m5\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\u001b[39m,\u001b[39m \u001b[39msys\u001b[39;00m\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m sys\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mappend(\u001b[39m'\u001b[39m\u001b[39m/home/xumin/workspace/MyShell-VC-Training/text/es_phonemizer/\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mes_to_ipa\u001b[39;00m \u001b[39mimport\u001b[39;00m es2ipa\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39msplit_sentences_en\u001b[39m(text, min_len\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m):\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a>\u001b[0m   \u001b[39m# 将文本中的换行符、空格和制表符替换为空格\u001b[39;00m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Bcatams4/home/xumin/workspace/Bert-VITS2/text/es_phonemizer/test.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>\u001b[0m   text \u001b[39m=\u001b[39m re\u001b[39m.\u001b[39msub(\u001b[39m'\u001b[39m\u001b[39m[\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m ]+\u001b[39m\u001b[39m'\u001b[39m, \u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m, text)\n",
         | 
| 16 | 
            +
                  "File \u001b[0;32m/data/workspace/Bert-VITS2/text/es_phonemizer/es_to_ipa.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mcleaner\u001b[39;00m \u001b[39mimport\u001b[39;00m spanish_cleaners\n\u001b[1;32m      2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39m.\u001b[39;00m\u001b[39mgruut_wrapper\u001b[39;00m \u001b[39mimport\u001b[39;00m Gruut\n\u001b[1;32m      4\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mes2ipa\u001b[39m(text):\n",
         | 
| 17 | 
            +
                  "\u001b[0;31mImportError\u001b[0m: attempted relative import with no known parent package"
         | 
| 18 | 
            +
                 ]
         | 
| 19 | 
            +
                }
         | 
| 20 | 
            +
               ],
         | 
| 21 | 
            +
               "source": [
         | 
| 22 | 
            +
                "import re\n",
         | 
| 23 | 
            +
                "import os\n",
         | 
| 24 | 
            +
                "import os, sys\n",
         | 
| 25 | 
            +
                "sys.path.append('/home/xumin/workspace/MyShell-VC-Training/text/es_phonemizer/')\n",
         | 
| 26 | 
            +
                "from es_to_ipa import es2ipa\n",
         | 
| 27 | 
            +
                "\n",
         | 
| 28 | 
            +
                "\n",
         | 
| 29 | 
            +
                "\n",
         | 
| 30 | 
            +
                "def split_sentences_en(text, min_len=10):\n",
         | 
| 31 | 
            +
                "  # 将文本中的换行符、空格和制表符替换为空格\n",
         | 
| 32 | 
            +
                "  text = re.sub('[\\n\\t ]+', ' ', text)\n",
         | 
| 33 | 
            +
                "  # 在标点符号后添加一个空格\n",
         | 
| 34 | 
            +
                "  text = re.sub('([¿—¡])', r'\\1 $#!', text)\n",
         | 
| 35 | 
            +
                "  # 分隔句子并去除前后空格\n",
         | 
| 36 | 
            +
                "  \n",
         | 
| 37 | 
            +
                "  sentences = [s.strip() for s in text.split(' $#!')]\n",
         | 
| 38 | 
            +
                "  if len(sentences[-1]) == 0: del sentences[-1]\n",
         | 
| 39 | 
            +
                "\n",
         | 
| 40 | 
            +
                "  new_sentences = []\n",
         | 
| 41 | 
            +
                "  new_sent = []\n",
         | 
| 42 | 
            +
                "  for ind, sent in enumerate(sentences):\n",
         | 
| 43 | 
            +
                "    if sent in ['¿', '—', '¡']:\n",
         | 
| 44 | 
            +
                "      new_sent.append(sent)\n",
         | 
| 45 | 
            +
                "    else:\n",
         | 
| 46 | 
            +
                "      new_sent.append(es2ipa(sent))\n",
         | 
| 47 | 
            +
                "    \n",
         | 
| 48 | 
            +
                "  \n",
         | 
| 49 | 
            +
                "  new_sentences = ''.join(new_sent)\n",
         | 
| 50 | 
            +
                "\n",
         | 
| 51 | 
            +
                "  return new_sentences"
         | 
| 52 | 
            +
               ]
         | 
| 53 | 
            +
              },
         | 
| 54 | 
            +
              {
         | 
| 55 | 
            +
               "cell_type": "code",
         | 
| 56 | 
            +
               "execution_count": 3,
         | 
| 57 | 
            +
               "metadata": {},
         | 
| 58 | 
            +
               "outputs": [
         | 
| 59 | 
            +
                {
         | 
| 60 | 
            +
                 "data": {
         | 
| 61 | 
            +
                  "text/plain": [
         | 
| 62 | 
            +
                   "'—¿aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'"
         | 
| 63 | 
            +
                  ]
         | 
| 64 | 
            +
                 },
         | 
| 65 | 
            +
                 "execution_count": 3,
         | 
| 66 | 
            +
                 "metadata": {},
         | 
| 67 | 
            +
                 "output_type": "execute_result"
         | 
| 68 | 
            +
                }
         | 
| 69 | 
            +
               ],
         | 
| 70 | 
            +
               "source": [
         | 
| 71 | 
            +
                "split_sentences_en('—¿Habéis estado casada alguna vez?')"
         | 
| 72 | 
            +
               ]
         | 
| 73 | 
            +
              },
         | 
| 74 | 
            +
              {
         | 
| 75 | 
            +
               "cell_type": "code",
         | 
| 76 | 
            +
               "execution_count": 4,
         | 
| 77 | 
            +
               "metadata": {},
         | 
| 78 | 
            +
               "outputs": [
         | 
| 79 | 
            +
                {
         | 
| 80 | 
            +
                 "data": {
         | 
| 81 | 
            +
                  "text/plain": [
         | 
| 82 | 
            +
                   "'aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'"
         | 
| 83 | 
            +
                  ]
         | 
| 84 | 
            +
                 },
         | 
| 85 | 
            +
                 "execution_count": 4,
         | 
| 86 | 
            +
                 "metadata": {},
         | 
| 87 | 
            +
                 "output_type": "execute_result"
         | 
| 88 | 
            +
                }
         | 
| 89 | 
            +
               ],
         | 
| 90 | 
            +
               "source": [
         | 
| 91 | 
            +
                "es2ipa('—¿Habéis estado casada alguna vez?')"
         | 
| 92 | 
            +
               ]
         | 
| 93 | 
            +
              },
         | 
| 94 | 
            +
              {
         | 
| 95 | 
            +
               "cell_type": "code",
         | 
| 96 | 
            +
               "execution_count": null,
         | 
| 97 | 
            +
               "metadata": {},
         | 
| 98 | 
            +
               "outputs": [],
         | 
| 99 | 
            +
               "source": []
         | 
| 100 | 
            +
              }
         | 
| 101 | 
            +
             ],
         | 
| 102 | 
            +
             "metadata": {
         | 
| 103 | 
            +
              "kernelspec": {
         | 
| 104 | 
            +
               "display_name": "base",
         | 
| 105 | 
            +
               "language": "python",
         | 
| 106 | 
            +
               "name": "python3"
         | 
| 107 | 
            +
              },
         | 
| 108 | 
            +
              "language_info": {
         | 
| 109 | 
            +
               "codemirror_mode": {
         | 
| 110 | 
            +
                "name": "ipython",
         | 
| 111 | 
            +
                "version": 3
         | 
| 112 | 
            +
               },
         | 
| 113 | 
            +
               "file_extension": ".py",
         | 
| 114 | 
            +
               "mimetype": "text/x-python",
         | 
| 115 | 
            +
               "name": "python",
         | 
| 116 | 
            +
               "nbconvert_exporter": "python",
         | 
| 117 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 118 | 
            +
               "version": "3.8.18"
         | 
| 119 | 
            +
              },
         | 
| 120 | 
            +
              "orig_nbformat": 4
         | 
| 121 | 
            +
             },
         | 
| 122 | 
            +
             "nbformat": 4,
         | 
| 123 | 
            +
             "nbformat_minor": 2
         | 
| 124 | 
            +
            }
         | 
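
The first notebook cell fails with an ImportError because es_to_ipa.py uses relative imports. Below is a hedged, standalone sketch of the same sentence-splitting idea, importing through the package instead; the function is renamed split_sentences_es here, and its handling of '¿', '—' and '¡' follows the notebook cell. The expected output depends on the installed gruut lexicons.

import re

# Importing via the package avoids the relative-import error shown in the notebook.
from melo.text.es_phonemizer.es_to_ipa import es2ipa

def split_sentences_es(text):
    # Collapse newlines, tabs and runs of spaces into single spaces.
    text = re.sub(r'[\n\t ]+', ' ', text)
    # Append a split marker after each inverted punctuation mark.
    text = re.sub(r'([¿—¡])', r'\1 $#!', text)
    chunks = [s.strip() for s in text.split(' $#!') if s.strip()]
    # Keep the inverted marks verbatim; phonemize everything else.
    return ''.join(s if s in ('¿', '—', '¡') else es2ipa(s) for s in chunks)

# Expected to reproduce the notebook output: '—¿aβˈeis estˈaðo kasˈaða alɣˈuna bˈeθ?'
print(split_sentences_es('—¿Habéis estado casada alguna vez?'))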
    	
        melo/text/fr_phonemizer/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        melo/text/fr_phonemizer/base.py
    ADDED
    
    | @@ -0,0 +1,140 @@ | |
| 1 | 
            +
            import abc
         | 
| 2 | 
            +
            from typing import List, Tuple
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from .punctuation import Punctuation
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class BasePhonemizer(abc.ABC):
         | 
| 8 | 
            +
                """Base phonemizer class
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                Phonemization follows the following steps:
         | 
| 11 | 
            +
                    1. Preprocessing:
         | 
| 12 | 
            +
                        - remove empty lines
         | 
| 13 | 
            +
                        - remove punctuation
         | 
| 14 | 
            +
                        - keep track of punctuation marks
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                    2. Phonemization:
         | 
| 17 | 
            +
                        - convert text to phonemes
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                    3. Postprocessing:
         | 
| 20 | 
            +
                        - join phonemes
         | 
| 21 | 
            +
                        - restore punctuation marks
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                Args:
         | 
| 24 | 
            +
                    language (str):
         | 
| 25 | 
            +
                        Language used by the phonemizer.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    punctuations (List[str]):
         | 
| 28 | 
            +
                        List of punctuation marks to be preserved.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    keep_puncs (bool):
         | 
| 31 | 
            +
                        Whether to preserve punctuation marks or not.
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False):
         | 
| 35 | 
            +
                    # ensure the backend is installed on the system
         | 
| 36 | 
            +
                    if not self.is_available():
         | 
| 37 | 
            +
                        raise RuntimeError("{} not installed on your system".format(self.name()))  # pragma: nocover
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    # ensure the backend support the requested language
         | 
| 40 | 
            +
                    self._language = self._init_language(language)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                    # setup punctuation processing
         | 
| 43 | 
            +
                    self._keep_puncs = keep_puncs
         | 
| 44 | 
            +
                    self._punctuator = Punctuation(punctuations)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def _init_language(self, language):
         | 
| 47 | 
            +
                    """Language initialization
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                This method may be overridden in child classes (see the Segments backend)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    """
         | 
| 52 | 
            +
                    if not self.is_supported_language(language):
         | 
| 53 | 
            +
                        raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend")
         | 
| 54 | 
            +
                    return language
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                @property
         | 
| 57 | 
            +
                def language(self):
         | 
| 58 | 
            +
                    """The language code configured to be used for phonemization"""
         | 
| 59 | 
            +
                    return self._language
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                @staticmethod
         | 
| 62 | 
            +
                @abc.abstractmethod
         | 
| 63 | 
            +
                def name():
         | 
| 64 | 
            +
                    """The name of the backend"""
         | 
| 65 | 
            +
                    ...
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                @classmethod
         | 
| 68 | 
            +
                @abc.abstractmethod
         | 
| 69 | 
            +
                def is_available(cls):
         | 
| 70 | 
            +
                    """Returns True if the backend is installed, False otherwise"""
         | 
| 71 | 
            +
                    ...
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                @classmethod
         | 
| 74 | 
            +
                @abc.abstractmethod
         | 
| 75 | 
            +
                def version(cls):
         | 
| 76 | 
            +
                    """Return the backend version as a tuple (major, minor, patch)"""
         | 
| 77 | 
            +
                    ...
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                @staticmethod
         | 
| 80 | 
            +
                @abc.abstractmethod
         | 
| 81 | 
            +
                def supported_languages():
         | 
| 82 | 
            +
                    """Return a dict of language codes -> name supported by the backend"""
         | 
| 83 | 
            +
                    ...
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                def is_supported_language(self, language):
         | 
| 86 | 
            +
                    """Returns True if `language` is supported by the backend"""
         | 
| 87 | 
            +
                    return language in self.supported_languages()
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                @abc.abstractmethod
         | 
| 90 | 
            +
                def _phonemize(self, text, separator):
         | 
| 91 | 
            +
                    """The main phonemization method"""
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def _phonemize_preprocess(self, text) -> Tuple[List[str], List]:
         | 
| 94 | 
            +
                    """Preprocess the text before phonemization
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                1. strip leading/trailing whitespace
         | 
| 97 | 
            +
                    2. remove punctuation
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    Override this if you need a different behaviour
         | 
| 100 | 
            +
                    """
         | 
| 101 | 
            +
                    text = text.strip()
         | 
| 102 | 
            +
                    if self._keep_puncs:
         | 
| 103 | 
            +
                        # a tuple (text, punctuation marks)
         | 
| 104 | 
            +
                        return self._punctuator.strip_to_restore(text)
         | 
| 105 | 
            +
                    return [self._punctuator.strip(text)], []
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def _phonemize_postprocess(self, phonemized, punctuations) -> str:
         | 
| 108 | 
            +
                    """Postprocess the raw phonemized output
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    Override this if you need a different behaviour
         | 
| 111 | 
            +
                    """
         | 
| 112 | 
            +
                    if self._keep_puncs:
         | 
| 113 | 
            +
                        return self._punctuator.restore(phonemized, punctuations)[0]
         | 
| 114 | 
            +
                    return phonemized[0]
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                def phonemize(self, text: str, separator="|", language: str = None) -> str:  # pylint: disable=unused-argument
         | 
| 117 | 
            +
                    """Returns the `text` phonemized for the given language
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    Args:
         | 
| 120 | 
            +
                        text (str):
         | 
| 121 | 
            +
                            Text to be phonemized.
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                        separator (str):
         | 
| 124 | 
            +
                            string separator used between phonemes. Default to '_'.
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    Returns:
         | 
| 127 | 
            +
                        (str): Phonemized text
         | 
| 128 | 
            +
                    """
         | 
| 129 | 
            +
                    text, punctuations = self._phonemize_preprocess(text)
         | 
| 130 | 
            +
                    phonemized = []
         | 
| 131 | 
            +
                    for t in text:
         | 
| 132 | 
            +
                        p = self._phonemize(t, separator)
         | 
| 133 | 
            +
                        phonemized.append(p)
         | 
| 134 | 
            +
                    phonemized = self._phonemize_postprocess(phonemized, punctuations)
         | 
| 135 | 
            +
                    return phonemized
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                def print_logs(self, level: int = 0):
         | 
| 138 | 
            +
                    indent = "\t" * level
         | 
| 139 | 
            +
                    print(f"{indent}| > phoneme language: {self.language}")
         | 
| 140 | 
            +
                    print(f"{indent}| > phoneme backend: {self.name()}")
         | 
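
To make the abstract contract above concrete, here is a minimal sketch of a subclass; the identity "backend" is purely hypothetical and only shows how phonemize() threads preprocessing, _phonemize and punctuation restoration together.

from melo.text.fr_phonemizer.base import BasePhonemizer

class IdentityPhonemizer(BasePhonemizer):
    """Hypothetical backend that returns characters unchanged."""

    @staticmethod
    def name():
        return "identity"

    @classmethod
    def is_available(cls):
        return True

    @classmethod
    def version(cls):
        return (0, 0, 1)

    @staticmethod
    def supported_languages():
        return {"fr-fr": "French"}

    def _phonemize(self, text, separator):
        # A real backend would run G2P here; this just joins the characters.
        return separator.join(text)

# Punctuation is stripped before _phonemize and restored afterwards:
p = IdentityPhonemizer("fr-fr", keep_puncs=True)
print(p.phonemize("Oui, non!", separator="|"))  # -> 'O|u|i, n|o|n!'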
    	
        melo/text/fr_phonemizer/cleaner.py
    ADDED
    
    | @@ -0,0 +1,122 @@ | |
| 1 | 
            +
            """Set of default text cleaners"""
         | 
| 2 | 
            +
            # TODO: pick the cleaner for languages dynamically
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import re
         | 
| 5 | 
            +
            from .french_abbreviations import abbreviations_fr
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Regular expression matching whitespace:
         | 
| 8 | 
            +
            _whitespace_re = re.compile(r"\s+")
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            rep_map = {
         | 
| 12 | 
            +
                ":": ",",
         | 
| 13 | 
            +
                ";": ",",
         | 
| 14 | 
            +
                ",": ",",
         | 
| 15 | 
            +
                "。": ".",
         | 
| 16 | 
            +
                "!": "!",
         | 
| 17 | 
            +
                "?": "?",
         | 
| 18 | 
            +
                "\n": ".",
         | 
| 19 | 
            +
                "·": ",",
         | 
| 20 | 
            +
                "、": ",",
         | 
| 21 | 
            +
                "...": ".",
         | 
| 22 | 
            +
                "…": ".",
         | 
| 23 | 
            +
                "$": ".",
         | 
| 24 | 
            +
                "“": "",
         | 
| 25 | 
            +
                "”": "",
         | 
| 26 | 
            +
                "‘": "",
         | 
| 27 | 
            +
                "’": "",
         | 
| 28 | 
            +
                "(": "",
         | 
| 29 | 
            +
                ")": "",
         | 
| 30 | 
            +
                "(": "",
         | 
| 31 | 
            +
                ")": "",
         | 
| 32 | 
            +
                "《": "",
         | 
| 33 | 
            +
                "》": "",
         | 
| 34 | 
            +
                "【": "",
         | 
| 35 | 
            +
                "】": "",
         | 
| 36 | 
            +
                "[": "",
         | 
| 37 | 
            +
                "]": "",
         | 
| 38 | 
            +
                "—": "",
         | 
| 39 | 
            +
                "~": "-",
         | 
| 40 | 
            +
                "~": "-",
         | 
| 41 | 
            +
                "「": "",
         | 
| 42 | 
            +
                "」": "",
         | 
| 43 | 
            +
                "¿" : "",
         | 
| 44 | 
            +
                "¡" : ""
         | 
| 45 | 
            +
            }
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def replace_punctuation(text):
         | 
| 49 | 
            +
                pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
         | 
| 50 | 
            +
                replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
         | 
| 51 | 
            +
                return replaced_text
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            def expand_abbreviations(text, lang="fr"):
         | 
| 54 | 
            +
    # only French has an abbreviation table; fall back to an empty list to avoid a NameError
         | 
| 55 | 
            +
    _abbreviations = abbreviations_fr if lang == "fr" else []
         | 
| 56 | 
            +
                for regex, replacement in _abbreviations:
         | 
| 57 | 
            +
                    text = re.sub(regex, replacement, text)
         | 
| 58 | 
            +
                return text
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def lowercase(text):
         | 
| 62 | 
            +
                return text.lower()
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def collapse_whitespace(text):
         | 
| 66 | 
            +
                return re.sub(_whitespace_re, " ", text).strip()
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            def remove_punctuation_at_begin(text):
         | 
| 69 | 
            +
                return re.sub(r'^[,.!?]+', '', text)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            def remove_aux_symbols(text):
         | 
| 72 | 
            +
                text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
         | 
| 73 | 
            +
                return text
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def replace_symbols(text, lang="en"):
         | 
| 77 | 
            +
                """Replace symbols based on the lenguage tag.
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                Args:
         | 
| 80 | 
            +
                  text:
         | 
| 81 | 
            +
                   Input text.
         | 
| 82 | 
            +
                  lang:
         | 
| 83 | 
            +
        Language identifier, e.g. "en", "fr", "pt", "ca".
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                Returns:
         | 
| 86 | 
            +
                  The modified text
         | 
| 87 | 
            +
                  example:
         | 
| 88 | 
            +
                    input args:
         | 
| 89 | 
            +
                        text: "si l'avi cau, diguem-ho"
         | 
| 90 | 
            +
                        lang: "ca"
         | 
| 91 | 
            +
                    Output:
         | 
| 92 | 
            +
                        text: "si lavi cau, diguemho"
         | 
| 93 | 
            +
                """
         | 
| 94 | 
            +
                text = text.replace(";", ",")
         | 
| 95 | 
            +
                text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
         | 
| 96 | 
            +
                text = text.replace(":", ",")
         | 
| 97 | 
            +
                if lang == "en":
         | 
| 98 | 
            +
                    text = text.replace("&", " and ")
         | 
| 99 | 
            +
                elif lang == "fr":
         | 
| 100 | 
            +
                    text = text.replace("&", " et ")
         | 
| 101 | 
            +
                elif lang == "pt":
         | 
| 102 | 
            +
                    text = text.replace("&", " e ")
         | 
| 103 | 
            +
                elif lang == "ca":
         | 
| 104 | 
            +
                    text = text.replace("&", " i ")
         | 
| 105 | 
            +
                    text = text.replace("'", "")
         | 
| 106 | 
            +
    elif lang == "es":
         | 
| 107 | 
            +
        text = text.replace("&", " y ")  # pad with spaces like the other languages
         | 
| 108 | 
            +
                    text = text.replace("'", "")
         | 
| 109 | 
            +
                return text
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            def french_cleaners(text):
         | 
| 112 | 
            +
                """Pipeline for French text. There is no need to expand numbers, phonemizer already does that"""
         | 
| 113 | 
            +
                text = expand_abbreviations(text, lang="fr")
         | 
| 114 | 
            +
                # text = lowercase(text) # as we use the cased bert
         | 
| 115 | 
            +
                text = replace_punctuation(text)
         | 
| 116 | 
            +
                text = replace_symbols(text, lang="fr")
         | 
| 117 | 
            +
                text = remove_aux_symbols(text)
         | 
| 118 | 
            +
                text = remove_punctuation_at_begin(text)
         | 
| 119 | 
            +
                text = collapse_whitespace(text)
         | 
| 120 | 
            +
                text = re.sub(r'([^\.,!\?\-…])$', r'\1.', text)
         | 
| 121 | 
            +
                return text
         | 
| 122 | 
            +
             | 
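
As a quick illustration of the pipeline order above (abbreviation expansion, punctuation map, symbol replacement, auxiliary-symbol removal, whitespace collapse, final period), here is the behaviour the rules imply on a made-up sentence (derived by hand from the rules, not a captured run):

    from melo.text.fr_phonemizer.cleaner import french_cleaners

    # "Mme." expands, ":" maps to ",", "«»" are dropped, and a final "." is appended
    print(french_cleaners("Mme. Dupont dit : «bonjour»"))
    # -> "Madame Dupont dit , bonjour."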
    	
        melo/text/fr_phonemizer/en_symbols.json
    ADDED
    
    | @@ -0,0 +1,78 @@ | |
| 1 | 
            +
            {"symbols": [
         | 
| 2 | 
            +
                "_",
         | 
| 3 | 
            +
                ",",
         | 
| 4 | 
            +
                ".",
         | 
| 5 | 
            +
                "!",
         | 
| 6 | 
            +
                "?",
         | 
| 7 | 
            +
                "-",
         | 
| 8 | 
            +
                "~",
         | 
| 9 | 
            +
                "\u2026",
         | 
| 10 | 
            +
                "N",
         | 
| 11 | 
            +
                "Q",
         | 
| 12 | 
            +
                "a",
         | 
| 13 | 
            +
                "b",
         | 
| 14 | 
            +
                "d",
         | 
| 15 | 
            +
                "e",
         | 
| 16 | 
            +
                "f",
         | 
| 17 | 
            +
                "g",
         | 
| 18 | 
            +
                "h",
         | 
| 19 | 
            +
                "i",
         | 
| 20 | 
            +
                "j",
         | 
| 21 | 
            +
                "k",
         | 
| 22 | 
            +
                "l",
         | 
| 23 | 
            +
                "m",
         | 
| 24 | 
            +
                "n",
         | 
| 25 | 
            +
                "o",
         | 
| 26 | 
            +
                "p",
         | 
| 27 | 
            +
                "s",
         | 
| 28 | 
            +
                "t",
         | 
| 29 | 
            +
                "u",
         | 
| 30 | 
            +
                "v",
         | 
| 31 | 
            +
                "w",
         | 
| 32 | 
            +
                "x",
         | 
| 33 | 
            +
                "y",
         | 
| 34 | 
            +
                "z",
         | 
| 35 | 
            +
                "\u0251",
         | 
| 36 | 
            +
                "\u00e6",
         | 
| 37 | 
            +
                "\u0283",
         | 
| 38 | 
            +
                "\u0291",
         | 
| 39 | 
            +
                "\u00e7",
         | 
| 40 | 
            +
                "\u026f",
         | 
| 41 | 
            +
                "\u026a",
         | 
| 42 | 
            +
                "\u0254",
         | 
| 43 | 
            +
                "\u025b",
         | 
| 44 | 
            +
                "\u0279",
         | 
| 45 | 
            +
                "\u00f0",
         | 
| 46 | 
            +
                "\u0259",
         | 
| 47 | 
            +
                "\u026b",
         | 
| 48 | 
            +
                "\u0265",
         | 
| 49 | 
            +
                "\u0278",
         | 
| 50 | 
            +
                "\u028a",
         | 
| 51 | 
            +
                "\u027e",
         | 
| 52 | 
            +
                "\u0292",
         | 
| 53 | 
            +
                "\u03b8",
         | 
| 54 | 
            +
                "\u03b2",
         | 
| 55 | 
            +
                "\u014b",
         | 
| 56 | 
            +
                "\u0266",
         | 
| 57 | 
            +
                "\u207c",
         | 
| 58 | 
            +
                "\u02b0",
         | 
| 59 | 
            +
                "`",
         | 
| 60 | 
            +
                "^",
         | 
| 61 | 
            +
                "#",
         | 
| 62 | 
            +
                "*",
         | 
| 63 | 
            +
                "=",
         | 
| 64 | 
            +
                "\u02c8",
         | 
| 65 | 
            +
                "\u02cc",
         | 
| 66 | 
            +
                "\u2192",
         | 
| 67 | 
            +
                "\u2193",
         | 
| 68 | 
            +
                "\u2191",
         | 
| 69 | 
            +
                " ",
         | 
| 70 | 
            +
                "ɣ",
         | 
| 71 | 
            +
                "ɡ", 
         | 
| 72 | 
            +
                "r", 
         | 
| 73 | 
            +
                "ɲ", 
         | 
| 74 | 
            +
                "ʝ", 
         | 
| 75 | 
            +
                "ʎ",
         | 
| 76 | 
            +
                "ː"
         | 
| 77 | 
            +
              ]
         | 
| 78 | 
            +
            }
         | 
    	
        melo/text/fr_phonemizer/fr_symbols.json
    ADDED
    
    | @@ -0,0 +1,89 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "symbols": [
         | 
| 3 | 
            +
                    "_",
         | 
| 4 | 
            +
                    ",",
         | 
| 5 | 
            +
                    ".",
         | 
| 6 | 
            +
                    "!",
         | 
| 7 | 
            +
                    "?",
         | 
| 8 | 
            +
                    "-",
         | 
| 9 | 
            +
                    "~",
         | 
| 10 | 
            +
                    "\u2026",
         | 
| 11 | 
            +
                    "N",
         | 
| 12 | 
            +
                    "Q",
         | 
| 13 | 
            +
                    "a",
         | 
| 14 | 
            +
                    "b",
         | 
| 15 | 
            +
                    "d",
         | 
| 16 | 
            +
                    "e",
         | 
| 17 | 
            +
                    "f",
         | 
| 18 | 
            +
                    "g",
         | 
| 19 | 
            +
                    "h",
         | 
| 20 | 
            +
                    "i",
         | 
| 21 | 
            +
                    "j",
         | 
| 22 | 
            +
                    "k",
         | 
| 23 | 
            +
                    "l",
         | 
| 24 | 
            +
                    "m",
         | 
| 25 | 
            +
                    "n",
         | 
| 26 | 
            +
                    "o",
         | 
| 27 | 
            +
                    "p",
         | 
| 28 | 
            +
                    "s",
         | 
| 29 | 
            +
                    "t",
         | 
| 30 | 
            +
                    "u",
         | 
| 31 | 
            +
                    "v",
         | 
| 32 | 
            +
                    "w",
         | 
| 33 | 
            +
                    "x",
         | 
| 34 | 
            +
                    "y",
         | 
| 35 | 
            +
                    "z",
         | 
| 36 | 
            +
                    "\u0251",
         | 
| 37 | 
            +
                    "\u00e6",
         | 
| 38 | 
            +
                    "\u0283",
         | 
| 39 | 
            +
                    "\u0291",
         | 
| 40 | 
            +
                    "\u00e7",
         | 
| 41 | 
            +
                    "\u026f",
         | 
| 42 | 
            +
                    "\u026a",
         | 
| 43 | 
            +
                    "\u0254",
         | 
| 44 | 
            +
                    "\u025b",
         | 
| 45 | 
            +
                    "\u0279",
         | 
| 46 | 
            +
                    "\u00f0",
         | 
| 47 | 
            +
                    "\u0259",
         | 
| 48 | 
            +
                    "\u026b",
         | 
| 49 | 
            +
                    "\u0265",
         | 
| 50 | 
            +
                    "\u0278",
         | 
| 51 | 
            +
                    "\u028a",
         | 
| 52 | 
            +
                    "\u027e",
         | 
| 53 | 
            +
                    "\u0292",
         | 
| 54 | 
            +
                    "\u03b8",
         | 
| 55 | 
            +
                    "\u03b2",
         | 
| 56 | 
            +
                    "\u014b",
         | 
| 57 | 
            +
                    "\u0266",
         | 
| 58 | 
            +
                    "\u207c",
         | 
| 59 | 
            +
                    "\u02b0",
         | 
| 60 | 
            +
                    "`",
         | 
| 61 | 
            +
                    "^",
         | 
| 62 | 
            +
                    "#",
         | 
| 63 | 
            +
                    "*",
         | 
| 64 | 
            +
                    "=",
         | 
| 65 | 
            +
                    "\u02c8",
         | 
| 66 | 
            +
                    "\u02cc",
         | 
| 67 | 
            +
                    "\u2192",
         | 
| 68 | 
            +
                    "\u2193",
         | 
| 69 | 
            +
                    "\u2191",
         | 
| 70 | 
            +
                    " ",
         | 
| 71 | 
            +
                    "\u0263",
         | 
| 72 | 
            +
                    "\u0261",
         | 
| 73 | 
            +
                    "r",
         | 
| 74 | 
            +
                    "\u0272",
         | 
| 75 | 
            +
                    "\u029d",
         | 
| 76 | 
            +
                    "\u028e",
         | 
| 77 | 
            +
                    "\u02d0",
         | 
| 78 | 
            +
                    
         | 
| 79 | 
            +
                    "\u0303",
         | 
| 80 | 
            +
                    "\u0153",
         | 
| 81 | 
            +
                    "\u00f8",
         | 
| 82 | 
            +
                    "\u0281",
         | 
| 83 | 
            +
                    "\u0252",
         | 
| 84 | 
            +
                    "\u028c",
         | 
| 85 | 
            +
                    "\u2014",
         | 
| 86 | 
            +
                    "\u025c",
         | 
| 87 | 
            +
                    "\u0250"
         | 
| 88 | 
            +
                ]
         | 
| 89 | 
            +
            }
         | 
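
The "\u0251"-style entries above are ordinary JSON escapes for IPA characters; decoding the file makes that concrete. A small sketch (the repository-relative path is an assumption):

    import json

    with open("melo/text/fr_phonemizer/fr_symbols.json", encoding="utf-8") as f:
        symbols = json.load(f)["symbols"]
    assert "\u0251" in symbols  # decodes to the IPA open back vowel "ɑ"
    assert " " in symbols       # plain space is a symbol too (word separator)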
    	
        melo/text/fr_phonemizer/fr_to_ipa.py
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
| 1 | 
            +
            from .cleaner import french_cleaners
         | 
| 2 | 
            +
            from .gruut_wrapper import Gruut
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def remove_consecutive_t(input_str):
         | 
| 6 | 
            +
                result = []
         | 
| 7 | 
            +
                count = 0
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                for char in input_str:
         | 
| 10 | 
            +
                    if char == 't':
         | 
| 11 | 
            +
                        count += 1
         | 
| 12 | 
            +
                    else:
         | 
| 13 | 
            +
        if count < 3:  # keep short runs; drop artifacts of three or more consecutive 't's
         | 
| 14 | 
            +
                            result.extend(['t'] * count)
         | 
| 15 | 
            +
                        count = 0
         | 
| 16 | 
            +
                        result.append(char)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
    if count < 3:  # flush a trailing run
         | 
| 19 | 
            +
                    result.extend(['t'] * count)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                return ''.join(result)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            def fr2ipa(text):
         | 
| 24 | 
            +
                e = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
         | 
| 25 | 
            +
                # text = french_cleaners(text)
         | 
| 26 | 
            +
                phonemes = e.phonemize(text, separator="")
         | 
| 27 | 
            +
                # print(phonemes)
         | 
| 28 | 
            +
                phonemes = remove_consecutive_t(phonemes)
         | 
| 29 | 
            +
                # print(phonemes)
         | 
| 30 | 
            +
                return phonemes
         | 
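
A minimal sketch of calling `fr2ipa` (the exact IPA depends on the installed gruut/espeak lexicons, so no output is shown):

    from melo.text.fr_phonemizer.fr_to_ipa import fr2ipa

    phonemes = fr2ipa("Bonjour tout le monde.")  # punctuation and stress marks kept
    print(phonemes)

Note that `fr2ipa` builds a fresh `Gruut` instance on every call; callers on a hot path may want to hoist it into a module-level singleton.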
    	
        melo/text/fr_phonemizer/french_abbreviations.py
    ADDED
    
    | @@ -0,0 +1,48 @@ | |
| 1 | 
            +
            import re
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # List of (regular expression, replacement) pairs for abbreviations in french:
         | 
| 4 | 
            +
            abbreviations_fr = [
         | 
| 5 | 
            +
                (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
         | 
| 6 | 
            +
                for x in [
         | 
| 7 | 
            +
                    ("M", "monsieur"),
         | 
| 8 | 
            +
                    ("Mlle", "mademoiselle"),
         | 
| 9 | 
            +
                    ("Mlles", "mesdemoiselles"),
         | 
| 10 | 
            +
                    ("Mme", "Madame"),
         | 
| 11 | 
            +
                    ("Mmes", "Mesdames"),
         | 
| 12 | 
            +
                    ("N.B", "nota bene"),
         | 
| 14 | 
            +
                    ("p.c.q", "parce que"),
         | 
| 15 | 
            +
                    ("Pr", "professeur"),
         | 
| 16 | 
            +
                    ("qqch", "quelque chose"),
         | 
| 17 | 
            +
                    ("rdv", "rendez-vous"),
         | 
| 18 | 
            +
                    ("max", "maximum"),
         | 
| 19 | 
            +
                    ("min", "minimum"),
         | 
| 20 | 
            +
                    ("no", "numéro"),
         | 
| 21 | 
            +
                    ("adr", "adresse"),
         | 
| 22 | 
            +
                    ("dr", "docteur"),
         | 
| 23 | 
            +
                    ("st", "saint"),
         | 
| 24 | 
            +
                    ("co", "companie"),
         | 
| 25 | 
            +
                    ("jr", "junior"),
         | 
| 26 | 
            +
                    ("sgt", "sergent"),
         | 
| 27 | 
            +
                    ("capt", "capitain"),
         | 
| 28 | 
            +
                    ("col", "colonel"),
         | 
| 29 | 
            +
                    ("av", "avenue"),
         | 
| 30 | 
            +
                    ("av. J.-C", "avant Jésus-Christ"),
         | 
| 31 | 
            +
                    ("apr. J.-C", "après Jésus-Christ"),
         | 
| 32 | 
            +
                    ("art", "article"),
         | 
| 33 | 
            +
                    ("boul", "boulevard"),
         | 
| 34 | 
            +
                    ("c.-à-d", "c’est-à-dire"),
         | 
| 35 | 
            +
                    ("etc", "et cetera"),
         | 
| 36 | 
            +
                    ("ex", "exemple"),
         | 
| 37 | 
            +
                    ("excl", "exclusivement"),
         | 
| 39 | 
            +
                ]
         | 
| 40 | 
            +
            ] + [
         | 
| 41 | 
            +
                (re.compile("\\b%s" % x[0]), x[1])
         | 
| 42 | 
            +
                for x in [
         | 
| 43 | 
            +
                    ("Mlle", "mademoiselle"),
         | 
| 44 | 
            +
                    ("Mlles", "mesdemoiselles"),
         | 
| 45 | 
            +
                    ("Mme", "Madame"),
         | 
| 46 | 
            +
                    ("Mmes", "Mesdames"),
         | 
| 47 | 
            +
                ]
         | 
| 48 | 
            +
            ]
         | 
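
To see what these regex/replacement pairs do, here is an illustrative expansion loop over a made-up string (it mirrors what `expand_abbreviations` in cleaner.py does):

    import re
    from melo.text.fr_phonemizer.french_abbreviations import abbreviations_fr

    text = "Dr. Martin habite av. de la République"
    for regex, replacement in abbreviations_fr:
        text = re.sub(regex, replacement, text)
    print(text)  # -> "docteur Martin habite avenue de la République"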
    	
        melo/text/fr_phonemizer/french_symbols.txt
    ADDED
    
    | @@ -0,0 +1 @@ | |
| 1 | 
            +
            _,.!?-~…NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ɣɡrɲʝʎː̃œøʁɒʌ—ɜɐ
         | 
    	
        melo/text/fr_phonemizer/gruut_wrapper.py
    ADDED
    
    | @@ -0,0 +1,258 @@ | |
| 1 | 
            +
            import importlib
         | 
| 2 | 
            +
            from typing import List
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import gruut
         | 
| 5 | 
            +
            from gruut_ipa import IPA # pip install gruut_ipa
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from .base import BasePhonemizer
         | 
| 8 | 
            +
            from .punctuation import Punctuation
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Table for str.translate to fix gruut/TTS phoneme mismatch
         | 
| 11 | 
            +
            GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class Gruut(BasePhonemizer):
         | 
| 15 | 
            +
                """Gruut wrapper for G2P
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                Args:
         | 
| 18 | 
            +
                    language (str):
         | 
| 19 | 
            +
                        Valid language code for the used backend.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                    punctuations (str):
         | 
| 22 | 
            +
                        Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    keep_puncs (bool):
         | 
| 25 | 
            +
                        If true, keep the punctuations after phonemization. Defaults to True.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    use_espeak_phonemes (bool):
         | 
| 28 | 
            +
                        If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    keep_stress (bool):
         | 
| 31 | 
            +
                        If true, keep the stress characters after phonemization. Defaults to False.
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                Example:
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
         | 
| 36 | 
            +
                    >>> phonemizer = Gruut('en-us')
         | 
| 37 | 
            +
                    >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|")
         | 
| 38 | 
            +
                    'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?'
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def __init__(
         | 
| 42 | 
            +
                    self,
         | 
| 43 | 
            +
                    language: str,
         | 
| 44 | 
            +
                    punctuations=Punctuation.default_puncs(),
         | 
| 45 | 
            +
                    keep_puncs=True,
         | 
| 46 | 
            +
                    use_espeak_phonemes=False,
         | 
| 47 | 
            +
                    keep_stress=False,
         | 
| 48 | 
            +
                ):
         | 
| 49 | 
            +
                    super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs)
         | 
| 50 | 
            +
                    self.use_espeak_phonemes = use_espeak_phonemes
         | 
| 51 | 
            +
                    self.keep_stress = keep_stress
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                @staticmethod
         | 
| 54 | 
            +
                def name():
         | 
| 55 | 
            +
                    return "gruut"
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str:  # pylint: disable=unused-argument
         | 
| 58 | 
            +
                    """Convert input text to phonemes.
         | 
| 59 | 
            +
             | 
| 60 | 
            +
        Gruut phonemizes the given `str` by separating each phoneme character with `separator`, even for characters
         | 
| 61 | 
            +
        that constitute a single sound.
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    It doesn't affect 🐸TTS since it individually converts each character to token IDs.
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    Examples::
         | 
| 66 | 
            +
                        "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ`
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    Args:
         | 
| 69 | 
            +
                        text (str):
         | 
| 70 | 
            +
                            Text to be converted to phonemes.
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                        tie (bool, optional) : When True use a '͡' character between
         | 
| 73 | 
            +
                            consecutive characters of a single phoneme. Else separate phoneme
         | 
| 74 | 
            +
            with '_'. This option requires espeak>=1.49. Defaults to False.
         | 
| 75 | 
            +
                    """
         | 
| 76 | 
            +
                    ph_list = []
         | 
| 77 | 
            +
                    for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes):
         | 
| 78 | 
            +
                        for word in sentence:
         | 
| 79 | 
            +
                            if word.is_break:
         | 
| 80 | 
            +
                                # Use actual character for break phoneme (e.g., comma)
         | 
| 81 | 
            +
                                if ph_list:
         | 
| 82 | 
            +
                                    # Join with previous word
         | 
| 83 | 
            +
                                    ph_list[-1].append(word.text)
         | 
| 84 | 
            +
                                else:
         | 
| 85 | 
            +
                                    # First word is punctuation
         | 
| 86 | 
            +
                                    ph_list.append([word.text])
         | 
| 87 | 
            +
                            elif word.phonemes:
         | 
| 88 | 
            +
                                # Add phonemes for word
         | 
| 89 | 
            +
                                word_phonemes = []
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                                for word_phoneme in word.phonemes:
         | 
| 92 | 
            +
                                    if not self.keep_stress:
         | 
| 93 | 
            +
                                        # Remove primary/secondary stress
         | 
| 94 | 
            +
                                        word_phoneme = IPA.without_stress(word_phoneme)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                                    word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                                    if word_phoneme:
         | 
| 99 | 
            +
                                        # Flatten phonemes
         | 
| 100 | 
            +
                                        word_phonemes.extend(word_phoneme)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                                if word_phonemes:
         | 
| 103 | 
            +
                                    ph_list.append(word_phonemes)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list]
         | 
| 106 | 
            +
                    ph = f"{separator} ".join(ph_words)
         | 
| 107 | 
            +
                    return ph
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def _phonemize(self, text, separator):
         | 
| 110 | 
            +
                    return self.phonemize_gruut(text, separator, tie=False)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def is_supported_language(self, language):
         | 
| 113 | 
            +
                    """Returns True if `language` is supported by the backend"""
         | 
| 114 | 
            +
                    return gruut.is_language_supported(language)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                @staticmethod
         | 
| 117 | 
            +
                def supported_languages() -> List:
         | 
| 118 | 
            +
                    """Get a dictionary of supported languages.
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    Returns:
         | 
| 121 | 
            +
                        List: List of language codes.
         | 
| 122 | 
            +
                    """
         | 
| 123 | 
            +
                    return list(gruut.get_supported_languages())
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def version(self):
         | 
| 126 | 
            +
                    """Get the version of the used backend.
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    Returns:
         | 
| 129 | 
            +
                        str: Version of the used backend.
         | 
| 130 | 
            +
                    """
         | 
| 131 | 
            +
                    return gruut.__version__
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                @classmethod
         | 
| 134 | 
            +
                def is_available(cls):
         | 
| 135 | 
            +
                    """Return true if ESpeak is available else false"""
         | 
| 136 | 
            +
                    return importlib.util.find_spec("gruut") is not None
         | 
| 137 | 
            +
             | 
| 138 | 
            +
             | 
| 139 | 
            +
            if __name__ == "__main__":
         | 
| 140 | 
            +
                from cleaner import french_cleaners
         | 
| 141 | 
            +
                import json
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                e = Gruut(language="fr-fr", keep_puncs=True, keep_stress=True, use_espeak_phonemes=True)
         | 
| 144 | 
            +
                symbols = [  # en + sp
         | 
| 145 | 
            +
                    "_",
         | 
| 146 | 
            +
                    ",",
         | 
| 147 | 
            +
                    ".",
         | 
| 148 | 
            +
                    "!",
         | 
| 149 | 
            +
                    "?",
         | 
| 150 | 
            +
                    "-",
         | 
| 151 | 
            +
                    "~",
         | 
| 152 | 
            +
                    "\u2026",
         | 
| 153 | 
            +
                    "N",
         | 
| 154 | 
            +
                    "Q",
         | 
| 155 | 
            +
                    "a",
         | 
| 156 | 
            +
                    "b",
         | 
| 157 | 
            +
                    "d",
         | 
| 158 | 
            +
                    "e",
         | 
| 159 | 
            +
                    "f",
         | 
| 160 | 
            +
                    "g",
         | 
| 161 | 
            +
                    "h",
         | 
| 162 | 
            +
                    "i",
         | 
| 163 | 
            +
                    "j",
         | 
| 164 | 
            +
                    "k",
         | 
| 165 | 
            +
                    "l",
         | 
| 166 | 
            +
                    "m",
         | 
| 167 | 
            +
                    "n",
         | 
| 168 | 
            +
                    "o",
         | 
| 169 | 
            +
                    "p",
         | 
| 170 | 
            +
                    "s",
         | 
| 171 | 
            +
                    "t",
         | 
| 172 | 
            +
                    "u",
         | 
| 173 | 
            +
                    "v",
         | 
| 174 | 
            +
                    "w",
         | 
| 175 | 
            +
                    "x",
         | 
| 176 | 
            +
                    "y",
         | 
| 177 | 
            +
                    "z",
         | 
| 178 | 
            +
                    "\u0251",
         | 
| 179 | 
            +
                    "\u00e6",
         | 
| 180 | 
            +
                    "\u0283",
         | 
| 181 | 
            +
                    "\u0291",
         | 
| 182 | 
            +
                    "\u00e7",
         | 
| 183 | 
            +
                    "\u026f",
         | 
| 184 | 
            +
                    "\u026a",
         | 
| 185 | 
            +
                    "\u0254",
         | 
| 186 | 
            +
                    "\u025b",
         | 
| 187 | 
            +
                    "\u0279",
         | 
| 188 | 
            +
                    "\u00f0",
         | 
| 189 | 
            +
                    "\u0259",
         | 
| 190 | 
            +
                    "\u026b",
         | 
| 191 | 
            +
                    "\u0265",
         | 
| 192 | 
            +
                    "\u0278",
         | 
| 193 | 
            +
                    "\u028a",
         | 
| 194 | 
            +
                    "\u027e",
         | 
| 195 | 
            +
                    "\u0292",
         | 
| 196 | 
            +
                    "\u03b8",
         | 
| 197 | 
            +
                    "\u03b2",
         | 
| 198 | 
            +
                    "\u014b",
         | 
| 199 | 
            +
                    "\u0266",
         | 
| 200 | 
            +
                    "\u207c",
         | 
| 201 | 
            +
                    "\u02b0",
         | 
| 202 | 
            +
                    "`",
         | 
| 203 | 
            +
                    "^",
         | 
| 204 | 
            +
                    "#",
         | 
| 205 | 
            +
                    "*",
         | 
| 206 | 
            +
                    "=",
         | 
| 207 | 
            +
                    "\u02c8",
         | 
| 208 | 
            +
                    "\u02cc",
         | 
| 209 | 
            +
                    "\u2192",
         | 
| 210 | 
            +
                    "\u2193",
         | 
| 211 | 
            +
                    "\u2191",
         | 
| 212 | 
            +
                    " ",
         | 
| 213 | 
            +
                    "ɣ",
         | 
| 214 | 
            +
                    "ɡ", 
         | 
| 215 | 
            +
                    "r", 
         | 
| 216 | 
            +
                    "ɲ", 
         | 
| 217 | 
            +
                    "ʝ", 
         | 
| 218 | 
            +
                    "ʎ",
         | 
| 219 | 
            +
                    "ː"
         | 
| 220 | 
            +
                ]
         | 
| 221 | 
            +
                with open('/home/xumin/workspace/VITS-Training-Multiling/230715_fr/metadata.txt', 'r') as f:
         | 
| 222 | 
            +
                    lines = f.readlines()
         | 
| 223 | 
            +
                
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                used_sym = []
         | 
| 226 | 
            +
                not_existed_sym = []
         | 
| 227 | 
            +
                phonemes = []
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                for line in lines:
         | 
| 230 | 
            +
                    text = line.split('|')[-1].strip()
         | 
| 231 | 
            +
                    text = french_cleaners(text)
         | 
| 232 | 
            +
                    ipa =  e.phonemize(text, separator="")
         | 
| 233 | 
            +
                    phonemes.append(ipa)
         | 
| 234 | 
            +
                    for s in ipa:
         | 
| 235 | 
            +
                        if s not in symbols:
         | 
| 236 | 
            +
                            if s not in not_existed_sym:
         | 
| 237 | 
            +
                                print(f'not_existed char: {s}')
         | 
| 238 | 
            +
                                not_existed_sym.append(s)
         | 
| 239 | 
            +
                        else:
         | 
| 240 | 
            +
                            if s not in used_sym:
         | 
| 241 | 
            +
                                # print(f'used char: {s}')
         | 
| 242 | 
            +
                                used_sym.append(s)
         | 
| 243 | 
            +
                
         | 
| 244 | 
            +
                print(used_sym)
         | 
| 245 | 
            +
                print(not_existed_sym)
         | 
| 246 | 
            +
             | 
| 247 | 
            +
             | 
| 248 | 
            +
                with open('./text/fr_phonemizer/french_symbols.txt', 'w') as g:
         | 
| 249 | 
            +
                    g.writelines(symbols + not_existed_sym)
         | 
| 250 | 
            +
                    
         | 
| 251 | 
            +
                with open('./text/fr_phonemizer/example_ipa.txt', 'w') as g:
         | 
| 252 | 
            +
                    g.writelines(phonemes)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                data = {'symbols': symbols + not_existed_sym}
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                with open('./text/fr_phonemizer/fr_symbols.json', 'w') as f:
         | 
| 257 | 
            +
                    json.dump(data, f, indent=4)
         | 
| 258 | 
            +
             | 
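
The `GRUUT_TRANS_TABLE` at the top of this file exists because gruut emits the ASCII letter "g" (U+0067) while the symbol inventories in this commit use the IPA letter "ɡ" (U+0261); the two glyphs look identical but are different code points. A tiny sketch of the mapping:

    GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ")

    assert "gʁɑ̃".translate(GRUUT_TRANS_TABLE) == "ɡʁɑ̃"
    assert (ord("g"), ord("ɡ")) == (0x67, 0x261)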
    	
        melo/text/fr_phonemizer/punctuation.py
    ADDED
    
    | @@ -0,0 +1,172 @@ | |
| 1 | 
            +
            import collections
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            from enum import Enum
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import six
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            _DEF_PUNCS = ';:,.!?¡¿—…"«»“”'
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            _PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"])
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class PuncPosition(Enum):
         | 
| 13 | 
            +
                """Enum for the punctuations positions"""
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                BEGIN = 0
         | 
| 16 | 
            +
                END = 1
         | 
| 17 | 
            +
                MIDDLE = 2
         | 
| 18 | 
            +
                ALONE = 3
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            class Punctuation:
         | 
| 22 | 
            +
                """Handle punctuations in text.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                Just strip punctuations from text or strip and restore them later.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                Args:
         | 
| 27 | 
            +
                    puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Example:
         | 
| 30 | 
            +
                    >>> punc = Punctuation()
         | 
| 31 | 
            +
                    >>> punc.strip("This is. example !")
         | 
| 32 | 
            +
                    'This is example'
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    >>> text_striped, punc_map = punc.strip_to_restore("This is. example !")
         | 
| 35 | 
            +
                    >>> ' '.join(text_striped)
         | 
| 36 | 
            +
                    'This is example'
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    >>> text_restored = punc.restore(text_striped, punc_map)
         | 
| 39 | 
            +
                    >>> text_restored[0]
         | 
| 40 | 
            +
                    'This is. example !'
         | 
| 41 | 
            +
                """
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                def __init__(self, puncs: str = _DEF_PUNCS):
         | 
| 44 | 
            +
                    self.puncs = puncs
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                @staticmethod
         | 
| 47 | 
            +
                def default_puncs():
         | 
| 48 | 
            +
                    """Return default set of punctuations."""
         | 
| 49 | 
            +
                    return _DEF_PUNCS
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                @property
         | 
| 52 | 
            +
                def puncs(self):
         | 
| 53 | 
            +
                    return self._puncs
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                @puncs.setter
         | 
| 56 | 
            +
                def puncs(self, value):
         | 
| 57 | 
            +
                    if not isinstance(value, six.string_types):
         | 
| 58 | 
            +
                        raise ValueError("[!] Punctuations must be of type str.")
         | 
| 59 | 
            +
        self._puncs = "".join(dict.fromkeys(value))  # remove duplicates without changing the order
         | 
| 60 | 
            +
                    self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+")
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def strip(self, text):
         | 
| 63 | 
            +
                    """Remove all the punctuations by replacing with `space`.
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    Args:
         | 
| 66 | 
            +
                        text (str): The text to be processed.
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    Example::
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                        "This is. example !" -> "This is example "
         | 
| 71 | 
            +
                    """
         | 
| 72 | 
            +
        return re.sub(self.puncs_regular_exp, " ", text).strip()
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                def strip_to_restore(self, text):
         | 
| 75 | 
            +
                    """Remove punctuations from text to restore them later.
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    Args:
         | 
| 78 | 
            +
                        text (str): The text to be processed.
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    Examples ::
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                        "This is. example !" -> [["This is", "example"], [".", "!"]]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    """
         | 
| 85 | 
            +
                    text, puncs = self._strip_to_restore(text)
         | 
| 86 | 
            +
                    return text, puncs
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def _strip_to_restore(self, text):
         | 
| 89 | 
            +
                    """Auxiliary method for Punctuation.preserve()"""
         | 
| 90 | 
            +
                    matches = list(re.finditer(self.puncs_regular_exp, text))
         | 
| 91 | 
            +
                    if not matches:
         | 
| 92 | 
            +
                        return [text], []
         | 
| 93 | 
            +
                    # the text is only punctuations
         | 
| 94 | 
            +
                    if len(matches) == 1 and matches[0].group() == text:
         | 
| 95 | 
            +
                        return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
         | 
| 96 | 
            +
                    # build a punctuation map to be used later to restore punctuations
         | 
| 97 | 
            +
                    puncs = []
         | 
| 98 | 
            +
                    for match in matches:
         | 
| 99 | 
            +
                        position = PuncPosition.MIDDLE
         | 
| 100 | 
            +
                        if match == matches[0] and text.startswith(match.group()):
         | 
| 101 | 
            +
                            position = PuncPosition.BEGIN
         | 
| 102 | 
            +
                        elif match == matches[-1] and text.endswith(match.group()):
         | 
| 103 | 
            +
                            position = PuncPosition.END
         | 
| 104 | 
            +
                        puncs.append(_PUNC_IDX(match.group(), position))
         | 
| 105 | 
            +
                    # convert str text to a List[str], each item is separated by a punctuation
         | 
| 106 | 
            +
                    splitted_text = []
         | 
| 107 | 
            +
                    for idx, punc in enumerate(puncs):
         | 
| 108 | 
            +
                        split = text.split(punc.punc)
         | 
| 109 | 
            +
                        prefix, suffix = split[0], punc.punc.join(split[1:])
         | 
| 110 | 
            +
                        splitted_text.append(prefix)
         | 
| 111 | 
            +
                        # if the text does not end with a punctuation, add it to the last item
         | 
| 112 | 
            +
                        if idx == len(puncs) - 1 and len(suffix) > 0:
         | 
| 113 | 
            +
                            splitted_text.append(suffix)
         | 
| 114 | 
            +
                        text = suffix
         | 
| 115 | 
            +
                    return splitted_text, puncs
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                @classmethod
         | 
| 118 | 
            +
                def restore(cls, text, puncs):
         | 
| 119 | 
            +
                    """Restore punctuation in a text.
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                    Args:
         | 
| 122 | 
            +
                        text (str): The text to be processed.
         | 
| 123 | 
            +
                        puncs (List[str]): The list of punctuations map to be used for restoring.
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    Examples ::
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        ['This is', 'example'], ['.', '!'] -> "This is. example!"
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    """
         | 
| 130 | 
            +
                    return cls._restore(text, puncs, 0)
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                @classmethod
         | 
| 133 | 
            +
                def _restore(cls, text, puncs, num):  # pylint: disable=too-many-return-statements
         | 
| 134 | 
            +
                    """Auxiliary method for Punctuation.restore()"""
         | 
| 135 | 
            +
                    if not puncs:
         | 
| 136 | 
            +
                        return text
         | 
| 137 | 
            +
             | 
| 138 | 
            +
        # nothing has been phonemized; return the puncs alone
         | 
| 139 | 
            +
                    if not text:
         | 
| 140 | 
            +
                        return ["".join(m.punc for m in puncs)]
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    current = puncs[0]
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    if current.position == PuncPosition.BEGIN:
         | 
| 145 | 
            +
                        return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    if current.position == PuncPosition.END:
         | 
| 148 | 
            +
                        return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    if current.position == PuncPosition.ALONE:
         | 
| 151 | 
            +
            return [current.punc] + cls._restore(text, puncs[1:], num + 1)  # namedtuple field is .punc, not .mark
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    # POSITION == MIDDLE
         | 
| 154 | 
            +
                    if len(text) == 1:  # pragma: nocover
         | 
| 155 | 
            +
                        # a corner case where the final part of an intermediate
         | 
| 156 | 
            +
                        # mark (I) has not been phonemized
         | 
| 157 | 
            +
                        return cls._restore([text[0] + current.punc], puncs[1:], num)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            # if __name__ == "__main__":
         | 
| 163 | 
            +
            #     punc = Punctuation()
         | 
| 164 | 
            +
            #     text = "This is. This is, example!"
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            #     print(punc.strip(text))
         | 
| 167 | 
            +
             | 
| 168 | 
            +
            #     split_text, puncs = punc.strip_to_restore(text)
         | 
| 169 | 
            +
            #     print(split_text, " ---- ", puncs)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
            #     restored_text = punc.restore(split_text, puncs)
         | 
| 172 | 
            +
            #     print(restored_text)
         | 
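
The commented-out block above sketches the intended round trip; spelled out as runnable code (the expected values follow the class docstring, not a captured run):

    from melo.text.fr_phonemizer.punctuation import Punctuation

    punc = Punctuation()
    split_text, puncs = punc.strip_to_restore("This is. This is, example!")
    # split_text == ['This is', 'This is', 'example']
    # puncs records ('. ', MIDDLE), (', ', MIDDLE), ('!', END)
    restored = punc.restore(split_text, puncs)
    print(restored[0])  # -> "This is. This is, example!"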
 
			
