import argparse
import os
import re
import time

import cn2an
import librosa
import numpy as np
import requests
import soundfile as sf
import torch

import onnxruntime as ort
import axengine as axe

# ASR helpers (SenseVoice, VAD, tokenizer)
from model import SinusoidalPositionEncoder
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.vad_utils import merge_vad
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

# MeloTTS text frontend
from libmelotts.python.split_utils import split_sentence
from libmelotts.python.text import cleaned_text_to_sequence
from libmelotts.python.text.cleaner import clean_text
from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP

# MeloTTS model assets
TTS_MODEL_DIR = "libmelotts/models"
TTS_MODEL_FILES = {
    "g": "g-zh_mix_en.bin",          # speaker embedding
    "encoder": "encoder-zh.onnx",    # encoder (ONNX, runs on CPU)
    "decoder": "decoder-zh.axmodel"  # decoder (AXEngine model)
}

# Default Qwen API endpoint; overridden via --api_url in main()
QWEN_API_URL = ""


def intersperse(lst, item):
    """Interleave `item` before, between and after the elements of `lst`."""
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result
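
# Example: intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]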


def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
    """Convert text to phone/tone/language ID arrays, keeping all arrays the same length."""
    try:
        norm_text, phone, tone, word2ph = clean_text(text, language_str)

        # IPA symbols that the ZH_MIX_EN symbol table does not contain; they
        # are dropped (together with their tones) rather than sent to the encoder.
        phones_to_drop = {
            'ɛ', 'æ', 'ʌ', 'ʊ', 'ɔ', 'ɪ', 'ɝ', 'ɚ', 'ɑ',
            'ʒ', 'θ', 'ð', 'ŋ', 'ʃ', 'ʧ', 'ʤ', 'ː', 'ˈ',
            'ˌ', 'ʰ', 'ʲ', 'ʷ', 'ʔ', 'ɾ', 'ɹ', 'ɫ', 'ɡ',
        }

        processed_phone = []
        processed_tone = []
        removed_symbols = set()

        for p, t in zip(phone, tone):
            if p in phones_to_drop:
                removed_symbols.add(p)
            elif p in symbol_to_id:
                processed_phone.append(p)
                processed_tone.append(t)
            else:
                # Any other symbol missing from the table is dropped as well.
                removed_symbols.add(p)

        if removed_symbols:
            print(f"[phone filter] dropped {len(removed_symbols)} special phones: {sorted(removed_symbols)}")
            print(f"[phone filter] phone sequence length after filtering: {len(processed_phone)}")
            print(f"[phone filter] tone sequence length after filtering: {len(processed_tone)}")

        if not processed_phone:
            print("[warning] no valid phones left, falling back to default Chinese phones")
            processed_phone = ['ni', 'hao']
            processed_tone = [1, 3]
            word2ph = [1, 1]

        if len(processed_phone) != len(phone):
            print(f"[warning] phone sequence length changed: {len(phone)} -> {len(processed_phone)}")
            # word2ph no longer matches the filtered phones; rebuild it 1:1.
            word2ph = [1] * len(processed_phone)

        phone, tone, language = cleaned_text_to_sequence(processed_phone, processed_tone, language_str, symbol_to_id)

        # Interleave blanks (ID 0): [p0, p1, ...] -> [0, p0, 0, p1, 0, ...]
        phone = intersperse(phone, 0)
        tone = intersperse(tone, 0)
        language = intersperse(language, 0)

        phone = np.array(phone, dtype=np.int32)
        tone = np.array(tone, dtype=np.int32)
        language = np.array(language, dtype=np.int32)
        # Each phone now occupies two slots (blank + phone); the extra leading
        # blank is attributed to the first word.
        word2ph = np.array(word2ph, dtype=np.int32) * 2
        word2ph[0] += 1
        return phone, tone, language, norm_text, word2ph

    except Exception as e:
        print(f"[error] text processing failed: {e}")
        import traceback
        traceback.print_exc()
        raise
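
# Note: after blank interleaving, `phone`, `tone` and `language` all have
# length 2*k + 1 for k surviving phones, and word2ph sums to that same length.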


def audio_numpy_concat(segment_data_list, sr, speed=1.):
    """Concatenate audio segments, inserting ~50 ms of silence between them."""
    audio_segments = []
    for segment_data in segment_data_list:
        audio_segments += segment_data.reshape(-1).tolist()
        audio_segments += [0] * int((sr * 0.05) / speed)
    audio_segments = np.array(audio_segments).astype(np.float32)
    return audio_segments
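
# Example: at sr=44100 and speed=1.0, each gap is int(44100 * 0.05) = 2205
# zero samples (~50 ms) appended after every segment.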


def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    # Average the overlapping pad regions of neighbouring chunks, then drop
    # the duplicated leading pad of the inner chunks.
    if pad_size > 0:
        for i in range(len(sub_audio_list) - 1):
            sub_audio_list[i][-pad_size:] += sub_audio_list[i + 1][:pad_size]
            sub_audio_list[i][-pad_size:] /= 2
            if i > 0:
                sub_audio_list[i] = sub_audio_list[i][pad_size:]

    sub_audio = np.concatenate(sub_audio_list, axis=-1)
    return sub_audio[:audio_len]
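
# Note: run_tts below calls this with pad_size=0, in which case it reduces to
# a plain concatenation truncated to audio_len samples.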


def calc_word2pronoun(word2ph, pronoun_lens):
    """Sum per-phone frame lengths into per-word frame lengths."""
    indice = [0]
    for ph in word2ph[:-1]:
        indice.append(indice[-1] + ph)
    word2pronoun = []
    for i, ph in zip(indice, word2ph):
        word2pronoun.append(np.sum(pronoun_lens[i : i + ph]))
    return word2pronoun
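
# Example: calc_word2pronoun([2, 1], [3, 4, 5]) -> [7, 5]
# (word 0 covers phones 0-1 with 3+4 frames, word 1 covers phone 2 with 5).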


def generate_slices(word2pronoun, dec_len):
    """Split per-word frame counts into decoder windows of at most dec_len frames.

    Returns parallel slices over words (pn) and over z_p frames (zp). Where
    possible, each window re-decodes the last two words of the previous window
    so the overlap can be trimmed cleanly later.
    """
    pn_start, pn_end = 0, 0
    zp_start, zp_end = 0, 0
    zp_len = 0
    pn_slices = []
    zp_slices = []
    while pn_end < len(word2pronoun):
        # Re-use a 2-word overlap if the previous window held more than 2 words
        # and the overlap plus the next word still fits in the decoder window.
        if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
            zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
            zp_start = zp_end - zp_len
            pn_start = pn_end - 2
        else:
            zp_len = 0
            zp_start = zp_end
            pn_start = pn_end

        # Greedily extend the window while the next word still fits.
        while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
            zp_len += word2pronoun[pn_end]
            pn_end += 1
        zp_end = zp_start + zp_len
        pn_slices.append(slice(pn_start, pn_end))
        zp_slices.append(slice(zp_start, zp_end))
    return pn_slices, zp_slices
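
# Worked example with dec_len=6: generate_slices([2, 2, 2, 2, 2], 6) gives
#   pn_slices = [slice(0, 3), slice(1, 4), slice(2, 5)]
#   zp_slices = [slice(0, 6), slice(2, 8), slice(4, 10)]
# i.e. each window re-decodes the last two words of the previous one.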


def lang_detect_with_regex(text):
    """Coarse language detection: returns 'chinese', 'english' or 'unknown'."""
    # Digits are language-neutral, so strip them before matching.
    text_without_digits = re.sub(r'\d+', '', text)

    if not text_without_digits:
        return 'unknown'

    # Any CJK unified ideograph marks the text as Chinese.
    if re.search(r'[\u4e00-\u9fff]', text_without_digits):
        return 'chinese'
    if re.search(r'[a-zA-Z]', text_without_digits):
        return 'english'
    return 'unknown'
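
# Examples: lang_detect_with_regex("hello 123") -> 'english',
# lang_detect_with_regex("你好123") -> 'chinese', lang_detect_with_regex("123") -> 'unknown'.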


class QwenTranslationAPI:
    """Client for a simple Qwen generation API: submit a prompt, then poll for chunks."""

    def __init__(self, api_url=QWEN_API_URL):
        self.api_url = api_url
        self.session_id = f"speech_translate_{int(time.time())}"

    def translate(self, text_content, max_retries=3, timeout=120):
        """Translate text via the Qwen API (ZH->EN or EN->ZH, picked automatically)."""
        if not text_content or text_content.strip() == "":
            return "Input text is empty"

        # Pick the translation direction from the detected source language.
        if lang_detect_with_regex(text_content) == 'chinese':
            prompt_f = "Translate into English"
        else:
            prompt_f = "Translate into Chinese"

        prompt = f"{prompt_f}: {text_content}"
        print(f"[translation API] sending request: {prompt}")

        for attempt in range(max_retries):
            try:
                # Kick off generation.
                generate_url = f"{self.api_url}/api/generate"
                payload = {
                    "prompt": prompt,
                    "temperature": 0.1,
                    "repetition_penalty": 1.0,
                    "top-p": 0.9,
                    "top-k": 40,
                    "max_new_tokens": 512
                }

                print(f"[translation API] starting generation (attempt {attempt + 1}/{max_retries})")
                response = requests.post(generate_url, json=payload, timeout=30)
                response.raise_for_status()
                print("[translation API] generation request accepted")

                # Poll for incremental results until the server reports done.
                result_url = f"{self.api_url}/api/generate_provider"
                start_time = time.time()
                full_translation = ""

                while time.time() - start_time < timeout:
                    try:
                        result_response = requests.get(result_url, timeout=10)
                        result_data = result_response.json()

                        current_chunk = result_data.get("response", "")
                        full_translation += current_chunk

                        if result_data.get("done", False):
                            print(f"[translation API] translation finished: {full_translation}")
                            return full_translation

                        time.sleep(0.05)

                    except requests.exceptions.RequestException as e:
                        print(f"[translation API] polling request failed: {e}")
                        if time.time() - start_time > timeout:
                            break
                        continue

                print(f"[translation API] polling timed out, retrying (attempt {attempt + 1})")

            except requests.exceptions.RequestException as e:
                print(f"[translation API] request failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # exponential back-off
                    print(f"[translation API] waiting {wait_time}s before retrying...")
                    time.sleep(wait_time)
                else:
                    return f"Translation failed: {str(e)}"
            except Exception as e:
                print(f"[translation API] unexpected error during translation: {e}")
                return f"Translation failed: {str(e)}"

        return "Translation timed out; please check the API service status"


class SpeechTranslationPipeline:
    """ASR -> translation -> TTS pipeline using local models plus a Qwen API."""

    def __init__(self,
                 tts_model_dir, tts_model_files,
                 asr_model_dir="ax_model", seq_len=132,
                 tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
                 qwen_api_url=QWEN_API_URL):
        self.tts_model_dir = tts_model_dir
        self.tts_model_files = tts_model_files
        self.asr_model_dir = asr_model_dir
        self.seq_len = seq_len
        self.tts_dec_len = tts_dec_len
        self.sample_rate = sample_rate
        self.tts_speed = tts_speed
        self.qwen_api_url = qwen_api_url

        # ASR models (VAD + SenseVoice)
        self._init_asr_models()

        # TTS models (MeloTTS)
        self._init_tts_models()

        # Translation client
        self.translator = QwenTranslationAPI(api_url=qwen_api_url)

        # Check model files and API reachability
        self._validate_files()

    def _init_asr_models(self):
        """Initialize the speech-recognition models."""
        print("Initializing SenseVoice models...")

        # Voice activity detection model
        self.model_vad = AX_Fsmn_vad(self.asr_model_dir)

        # Precompute sinusoidal position encodings for the fixed seq_len.
        self.embed = SinusoidalPositionEncoder()
        self.position_encoding = self.embed.get_position_encoding(
            torch.randn(1, self.seq_len, 560)).numpy()

        # SenseVoice ASR model
        self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)

        # SentencePiece tokenizer used to decode the ASR output
        tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)

        print("SenseVoice models initialized successfully.")

    def _init_tts_models(self):
        """Initialize the TTS models."""
        print("Initializing MeloTTS models...")
        init_start = time.time()

        # The encoder runs on CPU via ONNX Runtime; the decoder runs via AXEngine.
        enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
        dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])

        model_load_start = time.time()
        self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
        self.sess_dec = axe.InferenceSession(dec_model)
        print(f" Load encoder/decoder models: {(time.time() - model_load_start)*1000:.2f}ms")

        # Speaker embedding g, shape (1, 256, 1)
        g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
        self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)

        # Symbol table for mixed Chinese/English synthesis
        self.tts_language = "ZH_MIX_EN"
        self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}

        # Run the text frontend once so later requests do not pay the lazy
        # initialization cost (language models, tokenizers, etc.).
        print(" Warming up TTS modules (loading language models, tokenizers, etc.)...")
        try:
            warmup_start_mix = time.time()
            warmup_text_mix = "这是一个test测试。"
            _, _, _, _, _ = get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
            print(f" Mixed ZH-EN warm-up: {(time.time() - warmup_start_mix)*1000:.2f}ms")
        except Exception as e:
            print(f" Warning: Mixed warm-up failed: {e}")

        total_init_time = (time.time() - init_start) * 1000
        print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms ({total_init_time/1000:.2f}s)")

    def _validate_files(self):
        """Verify that all required files exist and the API is reachable."""
        for key, filename in self.tts_model_files.items():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS model file not found: {filepath}")

        # Best-effort reachability check against the Qwen API service.
        try:
            requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
            print("[API check] Qwen API service is reachable")
        except requests.exceptions.RequestException:
            print("[API warning] Cannot reach the Qwen API service; make sure it is running")

    def speech_recognition(self, speech, fs):
        """Step 1: speech recognition (VAD + ASR)."""
        speech_lengths = len(speech)

        # Segment the audio with VAD, merging segments up to 15 s long.
        print("Running VAD...")
        vad_start_time = time.time()
        res_vad = self.model_vad(speech)[0]
        vad_segments = merge_vad(res_vad, 15 * 1000)
        vad_time_cost = time.time() - vad_start_time
        print(f"VAD processing time: {vad_time_cost:.2f} seconds")
        print(f"VAD segments detected: {len(vad_segments)}")

        # Run ASR on each VAD segment and concatenate the transcripts.
        print("Running ASR...")
        asr_start_time = time.time()
        all_results = ""

        for i, segment in enumerate(vad_segments):
            # VAD timestamps are in milliseconds; convert to sample indices.
            segment_start, segment_end = segment
            start_sample = int(segment_start / 1000 * fs)
            end_sample = min(int(segment_end / 1000 * fs), speech_lengths)
            segment_speech = speech[start_sample:end_sample]

            # The ASR wrapper takes a file path, so round-trip through a temp WAV.
            segment_filename = f"temp_segment_{i}.wav"
            sf.write(segment_filename, segment_speech, fs)

            try:
                segment_res = self.model_bin(
                    segment_filename,
                    "auto",
                    True,
                    self.position_encoding,
                    tokenizer=self.tokenizer,
                )
                all_results += segment_res
            except Exception as e:
                print(f"Error processing segment {i}: {e}")
                continue
            finally:
                if os.path.exists(segment_filename):
                    os.remove(segment_filename)

        asr_time_cost = time.time() - asr_start_time
        print(f"ASR processing time: {asr_time_cost:.2f} seconds")
        print(f"ASR Result: {all_results}")

        return all_results.strip()
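
# Example: a VAD segment (2500, 4800) in milliseconds at fs=16000 maps to
# samples speech[40000:76800].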

    def run_translation(self, text_content):
        """Step 2: Chinese<->English translation via the Qwen LLM API."""
        print("Starting translation via API...")
        translation_start_time = time.time()

        translate_content = self.translator.translate(text_content)

        translation_time_cost = time.time() - translation_start_time
        print(f"Translation processing time: {translation_time_cost:.2f} seconds")
        print(f"Translation Result: {translate_content}")

        return translate_content

    def run_tts(self, translate_content, output_dir, output_wav=None):
        """Step 3: synthesize speech with the TTS models."""
        output_wav = output_wav or "output.wav"  # fall back to a default file name
        output_path = os.path.join(output_dir, output_wav)

        try:
            # Spell out Arabic numerals as Chinese numerals for Chinese text.
            if lang_detect_with_regex(translate_content) == "chinese":
                translate_content = cn2an.transform(translate_content, "an2cn")

            print(f"TTS synthesis for text: {translate_content}")

            # Split long text into sentences and synthesize them one by one.
            sens = split_sentence(translate_content, language_str=self.tts_language)
            print(f"Text split into {len(sens)} sentences")

            audio_list = []

            for n, se in enumerate(sens):
                # Insert a space at lower/upper-case boundaries so glued-together
                # English words are read separately.
                if self.tts_language in ['EN', 'ZH_MIX_EN']:
                    se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)

                print(f"Processing sentence[{n}]: {se}")

                # Text frontend: phone, tone and language ID sequences.
                phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
                    se, self.tts_language, symbol_to_id=self.symbol_to_id)

                # The encoder produces the latent z_p plus per-phone frame lengths.
                encoder_start = time.time()
                z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
                    'phone': phones, 'g': self.tts_g,
                    'tone': tones, 'language': lang_ids,
                    'noise_scale': np.array([0], dtype=np.float32),
                    'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
                    'noise_scale_w': np.array([0], dtype=np.float32),
                    'sdp_ratio': np.array([0], dtype=np.float32)})
                print(f"Encoder run time: {1000 * (time.time() - encoder_start):.2f}ms")

                # Group frame lengths by word and slice z_p into windows that
                # fit the fixed-length decoder.
                word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
                pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)

                audio_len = audio_len[0]
                sub_audio_list = []

                for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
                    zp_slice = z_p[..., zs]

                    sub_dec_len = zp_slice.shape[-1]
                    # The decoder upsamples by 512 samples per frame.
                    sub_audio_len = 512 * sub_dec_len

                    # Zero-pad the last window up to the fixed decoder length.
                    if zp_slice.shape[-1] < self.tts_dec_len:
                        zp_slice = np.concatenate((zp_slice, np.zeros((*zp_slice.shape[:-1], self.tts_dec_len - zp_slice.shape[-1]), dtype=np.float32)), axis=-1)

                    decoder_start = time.time()
                    audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()

                    # Trim the leading word if it was already decoded by the
                    # previous (overlapping) window.
                    audio_start = 0
                    if len(sub_audio_list) > 0:
                        if pn_slices[i - 1].stop > ps.start:
                            audio_start = 512 * word2pronoun[ps.start]

                    # Trim the trailing word if the next window re-decodes it.
                    audio_end = sub_audio_len
                    if i < len(pn_slices) - 1:
                        if ps.stop > pn_slices[i + 1].start:
                            audio_end = sub_audio_len - 512 * word2pronoun[ps.stop - 1]

                    audio = audio[audio_start:audio_end]
                    print(f"Decode slice[{i}]: decoder run time {1000 * (time.time() - decoder_start):.2f}ms")
                    sub_audio_list.append(audio)

                # Stitch the decoder windows back into one sentence waveform.
                sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len)
                audio_list.append(sub_audio)

            # Join sentences with short silences.
            audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)

            sf.write(output_path, audio, self.sample_rate)
            print(f"TTS audio saved to {output_path}")

            return output_path

        except Exception as e:
            print(f"TTS synthesis failed: {e}")
            import traceback
            traceback.print_exc()
            raise

    def full_pipeline(self, speech, fs, output_dir=None, output_tts=None):
        """Full pipeline: speech recognition -> translation -> TTS synthesis."""
        print("\n----------------------VAD+ASR----------------------------\n")
        start_time = time.time()
        text_content = self.speech_recognition(speech, fs)
        asr_time = time.time() - start_time
        print(f"Speech recognition took {asr_time:.2f} seconds")

        if not text_content or text_content.strip() == "":
            raise ValueError("ASR did not produce any usable text")

        print("\n---------------------Qwen Translation---------------------------\n")
        start_time = time.time()
        translate_content = self.run_translation(text_content)
        translate_time = time.time() - start_time
        print(f"Translation took {translate_time:.2f} seconds")

        print("-------------------------TTS-------------------------------\n")
        start_time = time.time()
        output_path = self.run_tts(translate_content, output_dir, output_tts)
        tts_time = time.time() - start_time
        print(f"TTS synthesis took {tts_time:.2f} seconds")

        return {
            "original_text": text_content,
            "translated_text": translate_content,
            "audio_path": output_path
        }


def main():
    parser = argparse.ArgumentParser(description="Speech Recognition, Translation and TTS Pipeline")
    parser.add_argument("--audio_file", type=str, default="./wav/en.mp3", help="Input audio file path")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory")
    parser.add_argument("--output_tts", type=str, default="output.wav", help="Output TTS file name")
    parser.add_argument("--api_url", type=str, default="http://10.126.29.158:8000", help="Qwen API server URL")

    args = parser.parse_args()
    print("-------------------START------------------------\n")
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Processing audio file: {args.audio_file}")

    # The ASR models expect 16 kHz audio; resample if needed.
    speech, fs = librosa.load(args.audio_file, sr=None)
    if fs != 16000:
        print(f"Resampling audio from {fs}Hz to 16000Hz")
        speech = librosa.resample(y=speech, orig_sr=fs, target_sr=16000)
        fs = 16000
    audio_duration = librosa.get_duration(y=speech, sr=fs)

    pipeline = SpeechTranslationPipeline(
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        asr_model_dir="ax_model",
        seq_len=132,
        tts_dec_len=128,
        sample_rate=44100,
        tts_speed=0.8,
        qwen_api_url=args.api_url
    )

    start_time = time.time()
    try:
        result = pipeline.full_pipeline(speech, fs, args.output_dir, args.output_tts)

        print("\n" + "="*50)
        print("Speech translation finished!")
        print("="*50 + "\n")
        print(f"Input audio: {args.audio_file}")
        print(f"Recognized text: {result['original_text']}")
        print(f"Translated text: {result['translated_text']}")
        print(f"Synthesized audio: {result['audio_path']}")

        # Persist a plain-text summary next to the synthesized audio.
        result_file = os.path.join(args.output_dir, "pipeline_result.txt")
        with open(result_file, 'w', encoding='utf-8') as f:
            f.write(f"Input audio: {args.audio_file}\n")
            f.write(f"Recognized text: {result['original_text']}\n")
            f.write(f"Translation result: {result['translated_text']}\n")
            f.write(f"Synthesized audio: {result['audio_path']}\n")

        # Real-time factor: total processing time divided by audio duration.
        time_cost = time.time() - start_time
        rtf = time_cost / audio_duration
        print(f"Inference time for {args.audio_file}: {time_cost:.2f} seconds")
        print(f"Audio duration: {audio_duration:.2f} seconds")
        print(f"RTF: {rtf:.2f}\n")
    except Exception as e:
        print(f"Pipeline execution failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
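
# Example invocation (the script name is a placeholder for this file):
#   python speech_translate_pipeline.py --audio_file ./wav/en.mp3 \
#       --output_dir ./output --output_tts output.wav --api_url http://<host>:8000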